Diffstat (limited to 'R2R/r2r/main/api')
-rwxr-xr-x  R2R/r2r/main/api/__init__.py              0
-rwxr-xr-x  R2R/r2r/main/api/client.py              377
-rwxr-xr-x  R2R/r2r/main/api/requests.py             79
-rwxr-xr-x  R2R/r2r/main/api/routes/__init__.py       0
-rwxr-xr-x  R2R/r2r/main/api/routes/base_router.py   75
-rwxr-xr-x  R2R/r2r/main/api/routes/ingestion.py     42
-rwxr-xr-x  R2R/r2r/main/api/routes/management.py   101
-rwxr-xr-x  R2R/r2r/main/api/routes/retrieval.py     91
8 files changed, 765 insertions, 0 deletions
diff --git a/R2R/r2r/main/api/__init__.py b/R2R/r2r/main/api/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/main/api/__init__.py
diff --git a/R2R/r2r/main/api/client.py b/R2R/r2r/main/api/client.py
new file mode 100755
index 00000000..b0f5b966
--- /dev/null
+++ b/R2R/r2r/main/api/client.py
@@ -0,0 +1,377 @@
+import asyncio
+import functools
+import json
+import os
+import threading
+import time
+import uuid
+from contextlib import ExitStack
+from typing import Any, AsyncGenerator, Generator, Optional, Union
+
+import fire
+import httpx
+import nest_asyncio
+import requests
+
+from .requests import (
+    R2RAnalyticsRequest,
+    R2RDeleteRequest,
+    R2RDocumentChunksRequest,
+    R2RDocumentsOverviewRequest,
+    R2RIngestFilesRequest,
+    R2RLogsRequest,
+    R2RPrintRelationshipsRequest,
+    R2RRAGRequest,
+    R2RSearchRequest,
+    R2RUpdateFilesRequest,
+    R2RUpdatePromptRequest,
+    R2RUsersOverviewRequest,
+)
+
+nest_asyncio.apply()
+
+
+class R2RHTTPError(Exception):
+    def __init__(self, status_code, error_type, message):
+        self.status_code = status_code
+        self.error_type = error_type
+        self.message = message
+        super().__init__(f"[{status_code}] {error_type}: {message}")
+
+
+def handle_request_error(response):
+    if response.status_code >= 400:
+        try:
+            error_content = response.json()
+            if isinstance(error_content, dict) and "detail" in error_content:
+                detail = error_content["detail"]
+                if isinstance(detail, dict):
+                    message = detail.get("message", str(response.text))
+                    error_type = detail.get("error_type", "UnknownError")
+                else:
+                    message = str(detail)
+                    error_type = "HTTPException"
+            else:
+                message = str(error_content)
+                error_type = "UnknownError"
+        except json.JSONDecodeError:
+            message = response.text
+            error_type = "UnknownError"
+
+        raise R2RHTTPError(
+            status_code=response.status_code,
+            error_type=error_type,
+            message=message,
+        )
+
+
+def monitor_request(func):
+    @functools.wraps(func)
+    def wrapper(*args, monitor=False, **kwargs):
+        if not monitor:
+            return func(*args, **kwargs)
+
+        result = None
+        exception = None
+
+        def run_func():
+            nonlocal result, exception
+            try:
+                result = func(*args, **kwargs)
+            except Exception as e:
+                exception = e
+
+        thread = threading.Thread(target=run_func)
+        thread.start()
+
+        dots = [".", "..", "..."]
+        i = 0
+        while thread.is_alive():
+            print(f"\rRequesting{dots[i % 3]}", end="", flush=True)
+            i += 1
+            time.sleep(0.5)
+
+        thread.join()
+
+        print("\r", end="", flush=True)
+
+        if exception:
+            raise exception
+        return result
+
+    return wrapper
+
+
+class R2RClient:
+    def __init__(self, base_url: str, prefix: str = "/v1"):
+        self.base_url = base_url
+        self.prefix = prefix
+
+    def _make_request(self, method, endpoint, **kwargs):
+        url = f"{self.base_url}{self.prefix}/{endpoint}"
+        response = requests.request(method, url, **kwargs)
+        handle_request_error(response)
+        return response.json()
+
+    def health(self) -> dict:
+        return self._make_request("GET", "health")
+
+    def update_prompt(
+        self,
+        name: str = "default_system",
+        template: Optional[str] = None,
+        input_types: Optional[dict] = None,
+    ) -> dict:
+        request = R2RUpdatePromptRequest(
+            name=name, template=template, input_types=input_types
+        )
+        return self._make_request(
+            "POST", "update_prompt", json=json.loads(request.json())
+        )
+
+    @monitor_request
+    def ingest_files(
+        self,
+        file_paths: list[str],
+        metadatas: Optional[list[dict]] = None,
+        document_ids: Optional[list[Union[uuid.UUID, str]]] = None,
+        versions: Optional[list[str]] = None,
+    ) -> dict:
+        all_file_paths = []
+
+        for path in file_paths:
+            if os.path.isdir(path):
+                for root, _, files in os.walk(path):
+                    all_file_paths.extend(
+                        os.path.join(root, file) for file in files
+                    )
+            else:
+                all_file_paths.append(path)
+
+        files_to_upload = [
+            (
+                "files",
+                (
+                    os.path.basename(file),
+                    open(file, "rb"),
+                    "application/octet-stream",
+                ),
+            )
+            for file in all_file_paths
+        ]
+        request = R2RIngestFilesRequest(
+            metadatas=metadatas,
+            document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            versions=versions,
+        )
+        try:
+            return self._make_request(
+                "POST",
+                "ingest_files",
+                data={
+                    k: json.dumps(v)
+                    for k, v in json.loads(request.json()).items()
+                },
+                files=files_to_upload,
+            )
+        finally:
+            for _, file_tuple in files_to_upload:
+                file_tuple[1].close()
+
+    @monitor_request
+    def update_files(
+        self,
+        file_paths: list[str],
+        document_ids: list[str],
+        metadatas: Optional[list[dict]] = None,
+    ) -> dict:
+        request = R2RUpdateFilesRequest(
+            metadatas=metadatas,
+            document_ids=document_ids,
+        )
+        with ExitStack() as stack:
+            return self._make_request(
+                "POST",
+                "update_files",
+                data={
+                    k: json.dumps(v)
+                    for k, v in json.loads(request.json()).items()
+                },
+                files=[
+                    (
+                        "files",
+                        (
+                            path.split("/")[-1],
+                            stack.enter_context(open(path, "rb")),
+                            "application/octet-stream",
+                        ),
+                    )
+                    for path in file_paths
+                ],
+            )
+
+    def search(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict[str, Any]] = {},
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+    ) -> dict:
+        request = R2RSearchRequest(
+            query=query,
+            vector_search_settings={
+                "use_vector_search": use_vector_search,
+                "search_filters": search_filters or {},
+                "search_limit": search_limit,
+                "do_hybrid_search": do_hybrid_search,
+            },
+            kg_search_settings={
+                "use_kg_search": use_kg_search,
+                "agent_generation_config": kg_agent_generation_config,
+            },
+        )
+        return self._make_request(
+            "POST", "search", json=json.loads(request.json())
+        )
+
+    def rag(
+        self,
+        query: str,
+        use_vector_search: bool = True,
+        search_filters: Optional[dict[str, Any]] = {},
+        search_limit: int = 10,
+        do_hybrid_search: bool = False,
+        use_kg_search: bool = False,
+        kg_agent_generation_config: Optional[dict] = None,
+        rag_generation_config: Optional[dict] = None,
+    ) -> dict:
+        request = R2RRAGRequest(
+            query=query,
+            vector_search_settings={
+                "use_vector_search": use_vector_search,
+                "search_filters": search_filters or {},
+                "search_limit": search_limit,
+                "do_hybrid_search": do_hybrid_search,
+            },
+            kg_search_settings={
+                "use_kg_search": use_kg_search,
+                "agent_generation_config": kg_agent_generation_config,
+            },
+            rag_generation_config=rag_generation_config,
+        )
+
+        if rag_generation_config and rag_generation_config.get(
+            "stream", False
+        ):
+            return self._stream_rag_sync(request)
+        else:
+            return self._make_request(
+                "POST", "rag", json=json.loads(request.json())
+            )
+
+    async def _stream_rag(
+        self, rag_request: R2RRAGRequest
+    ) -> AsyncGenerator[str, None]:
+        url = f"{self.base_url}{self.prefix}/rag"
+        async with httpx.AsyncClient() as client:
+            async with client.stream(
+                "POST", url, json=json.loads(rag_request.json())
+            ) as response:
+                handle_request_error(response)
+                async for chunk in response.aiter_text():
+                    yield chunk
+
+    def _stream_rag_sync(
+        self, rag_request: R2RRAGRequest
+    ) -> Generator[str, None, None]:
+        async def run_async_generator():
+            async for chunk in self._stream_rag(rag_request):
+                yield chunk
+
+        loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(loop)
+
+        async_gen = run_async_generator()
+
+        try:
+            while True:
+                chunk = loop.run_until_complete(async_gen.__anext__())
+                yield chunk
+        except StopAsyncIteration:
+            pass
+        finally:
+            loop.close()
+
+    def delete(
+        self, keys: list[str], values: list[Union[bool, int, str]]
+    ) -> dict:
+        request = R2RDeleteRequest(keys=keys, values=values)
+        return self._make_request(
+            "DELETE", "delete", json=json.loads(request.json())
+        )
+
+    def logs(self, log_type_filter: Optional[str] = None) -> dict:
+        request = R2RLogsRequest(log_type_filter=log_type_filter)
+        return self._make_request(
+            "GET", "logs", json=json.loads(request.json())
+        )
+
+    def app_settings(self) -> dict:
+        return self._make_request("GET", "app_settings")
+
+    def analytics(self, filter_criteria: dict, analysis_types: dict) -> dict:
+        request = R2RAnalyticsRequest(
+            filter_criteria=filter_criteria, analysis_types=analysis_types
+        )
+        return self._make_request(
+            "GET", "analytics", json=json.loads(request.json())
+        )
+
+    def users_overview(
+        self, user_ids: Optional[list[uuid.UUID]] = None
+    ) -> dict:
+        request = R2RUsersOverviewRequest(user_ids=user_ids)
+        return self._make_request(
+            "GET", "users_overview", json=json.loads(request.json())
+        )
+
+    def documents_overview(
+        self,
+        document_ids: Optional[list[str]] = None,
+        user_ids: Optional[list[str]] = None,
+    ) -> dict:
+        request = R2RDocumentsOverviewRequest(
+            document_ids=(
+                [uuid.UUID(did) for did in document_ids]
+                if document_ids
+                else None
+            ),
+            user_ids=(
+                [uuid.UUID(uid) for uid in user_ids] if user_ids else None
+            ),
+        )
+        return self._make_request(
+            "GET", "documents_overview", json=json.loads(request.json())
+        )
+
+    def document_chunks(self, document_id: str) -> dict:
+        request = R2RDocumentChunksRequest(document_id=document_id)
+        return self._make_request(
+            "GET", "document_chunks", json=json.loads(request.json())
+        )
+
+    def inspect_knowledge_graph(self, limit: int = 100) -> str:
+        request = R2RPrintRelationshipsRequest(limit=limit)
+        return self._make_request(
+            "POST", "inspect_knowledge_graph", json=json.loads(request.json())
+        )
+
+
+if __name__ == "__main__":
+    client = R2RClient(base_url="http://localhost:8000")
+    fire.Fire(client)
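For orientation, here is a minimal, hypothetical driver for the client added above. It is a sketch, not part of the diff: the import path is assumed from the tree layout, the base URL mirrors the `__main__` block, and the file path and queries are placeholders.

```python
from r2r.main.api.client import R2RClient  # import path assumed from the file tree above

client = R2RClient(base_url="http://localhost:8000")  # same default as the __main__ block

print(client.health())  # the /health route returns {"response": "ok"} unwrapped

# Ingest a local file; ingest_files() also walks directories recursively.
client.ingest_files(["./docs/example.txt"])  # placeholder path

# Vector search, then RAG over the same corpus.
print(client.search("What does the document say about X?", search_limit=5))
answer = client.rag(
    "Summarize the ingested document.",
    rag_generation_config={"stream": False},  # with "stream": True a generator of text chunks is returned
)
print(answer)
```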
diff --git a/R2R/r2r/main/api/requests.py b/R2R/r2r/main/api/requests.py
new file mode 100755
index 00000000..5c63ab82
--- /dev/null
+++ b/R2R/r2r/main/api/requests.py
@@ -0,0 +1,79 @@
+import uuid
+from typing import Optional, Union
+
+from pydantic import BaseModel
+
+from r2r.base import AnalysisTypes, FilterCriteria
+
+
+class R2RUpdatePromptRequest(BaseModel):
+    name: str
+    template: Optional[str] = None
+    input_types: Optional[dict[str, str]] = {}
+
+
+class R2RIngestFilesRequest(BaseModel):
+    document_ids: Optional[list[uuid.UUID]] = None
+    metadatas: Optional[list[dict]] = None
+    versions: Optional[list[str]] = None
+
+
+class R2RUpdateFilesRequest(BaseModel):
+    metadatas: Optional[list[dict]] = None
+    document_ids: Optional[list[uuid.UUID]] = None
+
+
+class R2RSearchRequest(BaseModel):
+    query: str
+    vector_search_settings: Optional[dict] = None
+    kg_search_settings: Optional[dict] = None
+
+
+class R2RRAGRequest(BaseModel):
+    query: str
+    vector_search_settings: Optional[dict] = None
+    kg_search_settings: Optional[dict] = None
+    rag_generation_config: Optional[dict] = None
+
+
+class R2REvalRequest(BaseModel):
+    query: str
+    context: str
+    completion: str
+
+
+class R2RDeleteRequest(BaseModel):
+    keys: list[str]
+    values: list[Union[bool, int, str]]
+
+
+class R2RAnalyticsRequest(BaseModel):
+    filter_criteria: FilterCriteria
+    analysis_types: AnalysisTypes
+
+
+class R2RUsersOverviewRequest(BaseModel):
+    user_ids: Optional[list[uuid.UUID]]
+
+
+class R2RDocumentsOverviewRequest(BaseModel):
+    document_ids: Optional[list[uuid.UUID]]
+    user_ids: Optional[list[uuid.UUID]]
+
+
+class R2RDocumentChunksRequest(BaseModel):
+    document_id: uuid.UUID
+
+
+class R2RLogsRequest(BaseModel):
+    log_type_filter: Optional[str] = None
+    max_runs_requested: int = 100
+
+
+class R2RPrintRelationshipsRequest(BaseModel):
+    limit: int = 100
+
+
+class R2RExtractionRequest(BaseModel):
+    entity_types: list[str]
+    relations: list[str]
diff --git a/R2R/r2r/main/api/routes/__init__.py b/R2R/r2r/main/api/routes/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/main/api/routes/__init__.py
diff --git a/R2R/r2r/main/api/routes/base_router.py b/R2R/r2r/main/api/routes/base_router.py
new file mode 100755
index 00000000..d06a9935
--- /dev/null
+++ b/R2R/r2r/main/api/routes/base_router.py
@@ -0,0 +1,75 @@
+import functools
+import logging
+
+from fastapi import APIRouter, HTTPException
+from fastapi.responses import StreamingResponse
+
+from r2r.base import R2RException, manage_run
+
+logger = logging.getLogger(__name__)
+
+
+class BaseRouter:
+    def __init__(self, engine):
+        self.engine = engine
+        self.router = APIRouter()
+
+    def base_endpoint(self, func):
+        @functools.wraps(func)
+        async def wrapper(*args, **kwargs):
+            async with manage_run(
+                self.engine.run_manager, func.__name__
+            ) as run_id:
+                try:
+                    results = await func(*args, **kwargs)
+                    if isinstance(results, StreamingResponse):
+                        return results
+
+                    return {"results": results}
+                except R2RException as re:
+                    raise HTTPException(
+                        status_code=re.status_code,
+                        detail={
+                            "message": re.message,
+                            "error_type": type(re).__name__,
+                        },
+                    )
+                except Exception as e:
+                    # Get the pipeline name based on the function name
+                    pipeline_name = f"{func.__name__.split('_')[0]}_pipeline"
+
+                    # Safely get the pipeline object and its type
+                    pipeline = getattr(
+                        self.engine.pipelines, pipeline_name, None
+                    )
+                    pipeline_type = getattr(
+                        pipeline, "pipeline_type", "unknown"
+                    )
+
+                    await self.engine.logging_connection.log(
+                        log_id=run_id,
+                        key="pipeline_type",
+                        value=pipeline_type,
+                        is_info_log=True,
+                    )
+                    await self.engine.logging_connection.log(
+                        log_id=run_id,
+                        key="error",
+                        value=str(e),
+                        is_info_log=False,
+                    )
+                    logger.error(f"{func.__name__}() - \n\n{str(e)})")
+                    raise HTTPException(
+                        status_code=500,
+                        detail={
+                            "message": f"An error occurred during {func.__name__}",
+                            "error": str(e),
+                            "error_type": type(e).__name__,
+                        },
+                    ) from e
+
+        return wrapper
+
+    @classmethod
+    def build_router(cls, engine):
+        return cls(engine).router
diff --git a/R2R/r2r/main/api/routes/ingestion.py b/R2R/r2r/main/api/routes/ingestion.py
new file mode 100755
index 00000000..be583602
--- /dev/null
+++ b/R2R/r2r/main/api/routes/ingestion.py
@@ -0,0 +1,42 @@
+from fastapi import Depends, File, UploadFile
+
+from ...engine import R2REngine
+from ...services.ingestion_service import IngestionService
+from ..requests import R2RIngestFilesRequest, R2RUpdateFilesRequest
+from .base_router import BaseRouter
+
+
+class IngestionRouter(BaseRouter):
+    def __init__(self, engine: R2REngine):
+        super().__init__(engine)
+        self.setup_routes()
+
+    def setup_routes(self):
+        @self.router.post("/ingest_files")
+        @self.base_endpoint
+        async def ingest_files_app(
+            files: list[UploadFile] = File(...),
+            request: R2RIngestFilesRequest = Depends(
+                IngestionService.parse_ingest_files_form_data
+            ),
+        ):
+            return await self.engine.aingest_files(
+                files=files,
+                metadatas=request.metadatas,
+                document_ids=request.document_ids,
+                versions=request.versions,
+            )
+
+        @self.router.post("/update_files")
+        @self.base_endpoint
+        async def update_files_app(
+            files: list[UploadFile] = File(...),
+            request: R2RUpdateFilesRequest = Depends(
+                IngestionService.parse_update_files_form_data
+            ),
+        ):
+            return await self.engine.aupdate_files(
+                files=files,
+                metadatas=request.metadatas,
+                document_ids=request.document_ids,
+            )
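The `base_endpoint` decorator in base_router.py and `handle_request_error()` in client.py agree on a simple response envelope. A sketch of both shapes, with illustrative values only:

```python
# Successful non-streaming endpoints are wrapped by base_endpoint:
success_body = {"results": {"search_results": []}}  # payload shape depends on the engine

# Errors surface as a FastAPI HTTPException detail object:
error_body = {
    "detail": {
        "message": "An error occurred during search_app",
        "error": "division by zero",          # only present for unexpected exceptions
        "error_type": "ZeroDivisionError",
    }
}

# handle_request_error() in client.py converts that body into:
#   R2RHTTPError: [500] ZeroDivisionError: An error occurred during search_app
```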
diff --git a/R2R/r2r/main/api/routes/management.py b/R2R/r2r/main/api/routes/management.py
new file mode 100755
index 00000000..921fb534
--- /dev/null
+++ b/R2R/r2r/main/api/routes/management.py
@@ -0,0 +1,101 @@
+from ...engine import R2REngine
+from ..requests import (
+    R2RAnalyticsRequest,
+    R2RDeleteRequest,
+    R2RDocumentChunksRequest,
+    R2RDocumentsOverviewRequest,
+    R2RLogsRequest,
+    R2RPrintRelationshipsRequest,
+    R2RUpdatePromptRequest,
+    R2RUsersOverviewRequest,
+)
+from .base_router import BaseRouter
+
+
+class ManagementRouter(BaseRouter):
+    def __init__(self, engine: R2REngine):
+        super().__init__(engine)
+        self.setup_routes()
+
+    def setup_routes(self):
+        @self.router.get("/health")
+        async def health_check():
+            return {"response": "ok"}
+
+        @self.router.post("/update_prompt")
+        @self.base_endpoint
+        async def update_prompt_app(request: R2RUpdatePromptRequest):
+            return await self.engine.aupdate_prompt(
+                request.name, request.template, request.input_types
+            )
+
+        @self.router.post("/logs")
+        @self.router.get("/logs")
+        @self.base_endpoint
+        async def get_logs_app(request: R2RLogsRequest):
+            return await self.engine.alogs(
+                log_type_filter=request.log_type_filter,
+                max_runs_requested=request.max_runs_requested,
+            )
+
+        @self.router.post("/analytics")
+        @self.router.get("/analytics")
+        @self.base_endpoint
+        async def get_analytics_app(request: R2RAnalyticsRequest):
+            return await self.engine.aanalytics(
+                filter_criteria=request.filter_criteria,
+                analysis_types=request.analysis_types,
+            )
+
+        @self.router.post("/users_overview")
+        @self.router.get("/users_overview")
+        @self.base_endpoint
+        async def get_users_overview_app(request: R2RUsersOverviewRequest):
+            return await self.engine.ausers_overview(user_ids=request.user_ids)
+
+        @self.router.delete("/delete")
+        @self.base_endpoint
+        async def delete_app(request: R2RDeleteRequest):
+            return await self.engine.adelete(
+                keys=request.keys, values=request.values
+            )
+
+        @self.router.post("/documents_overview")
+        @self.router.get("/documents_overview")
+        @self.base_endpoint
+        async def get_documents_overview_app(
+            request: R2RDocumentsOverviewRequest,
+        ):
+            return await self.engine.adocuments_overview(
+                document_ids=request.document_ids, user_ids=request.user_ids
+            )
+
+        @self.router.post("/document_chunks")
+        @self.router.get("/document_chunks")
+        @self.base_endpoint
+        async def get_document_chunks_app(request: R2RDocumentChunksRequest):
+            return await self.engine.adocument_chunks(request.document_id)
+
+        @self.router.post("/inspect_knowledge_graph")
+        @self.router.get("/inspect_knowledge_graph")
+        @self.base_endpoint
+        async def inspect_knowledge_graph(
+            request: R2RPrintRelationshipsRequest,
+        ):
+            return await self.engine.inspect_knowledge_graph(
+                limit=request.limit
+            )
+
+        @self.router.get("/app_settings")
+        @self.base_endpoint
+        async def get_app_settings_app():
+            return await self.engine.aapp_settings()
+
+        @self.router.get("/openapi_spec")
+        @self.base_endpoint
+        def get_openapi_spec_app():
+            return self.engine.openapi_spec()
+
+
+def create_management_router(engine: R2REngine):
+    return ManagementRouter(engine).router
diff --git a/R2R/r2r/main/api/routes/retrieval.py b/R2R/r2r/main/api/routes/retrieval.py
new file mode 100755
index 00000000..b2d352aa
--- /dev/null
+++ b/R2R/r2r/main/api/routes/retrieval.py
@@ -0,0 +1,91 @@
+from fastapi.responses import StreamingResponse
+
+from r2r.base import GenerationConfig, KGSearchSettings, VectorSearchSettings
+
+from ...engine import R2REngine
+from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest
+from .base_router import BaseRouter
+
+
+class RetrievalRouter(BaseRouter):
+    def __init__(self, engine: R2REngine):
+        super().__init__(engine)
+        self.setup_routes()
+
+    def setup_routes(self):
+        @self.router.post("/search")
+        @self.base_endpoint
+        async def search_app(request: R2RSearchRequest):
+            if "agent_generation_config" in request.kg_search_settings:
+                request.kg_search_settings["agent_generation_config"] = (
+                    GenerationConfig(
+                        **request.kg_search_settings["agent_generation_config"]
+                        or {}
+                    )
+                )
+
+            results = await self.engine.asearch(
+                query=request.query,
+                vector_search_settings=VectorSearchSettings(
+                    **(request.vector_search_settings or {})
+                ),
+                kg_search_settings=KGSearchSettings(
+                    **(request.kg_search_settings or {})
+                ),
+            )
+            return results
+
+        @self.router.post("/rag")
+        @self.base_endpoint
+        async def rag_app(request: R2RRAGRequest):
+            if "agent_generation_config" in request.kg_search_settings:
+                request.kg_search_settings["agent_generation_config"] = (
+                    GenerationConfig(
+                        **(
+                            request.kg_search_settings[
+                                "agent_generation_config"
+                            ]
+                            or {}
+                        )
+                    )
+                )
+            response = await self.engine.arag(
+                query=request.query,
+                vector_search_settings=VectorSearchSettings(
+                    **(request.vector_search_settings or {})
+                ),
+                kg_search_settings=KGSearchSettings(
+                    **(request.kg_search_settings or {})
+                ),
+                rag_generation_config=GenerationConfig(
+                    **(request.rag_generation_config or {})
+                ),
+            )
+            if (
+                request.rag_generation_config
+                and request.rag_generation_config.get("stream", False)
+            ):
+
+                async def stream_generator():
+                    async for chunk in response:
+                        yield chunk
+
+                return StreamingResponse(
+                    stream_generator(), media_type="application/json"
+                )
+            else:
+                return response
+
+        @self.router.post("/evaluate")
+        @self.base_endpoint
+        async def evaluate_app(request: R2REvalRequest):
+            results = await self.engine.aevaluate(
+                query=request.query,
+                context=request.context,
+                completion=request.completion,
+            )
+            return results
+
+
+def create_retrieval_router(engine: R2REngine):
+    return RetrievalRouter(engine).router
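The diff stops at the routers themselves; how they are mounted into an application is not part of this change. A minimal sketch of one way to assemble them, assuming an already-constructed `R2REngine` instance supplied by code outside this diff:

```python
from fastapi import FastAPI

# Imports mirror the files added above; the app-assembly code itself is not in this diff.
from r2r.main.api.routes.ingestion import IngestionRouter
from r2r.main.api.routes.management import ManagementRouter
from r2r.main.api.routes.retrieval import RetrievalRouter


def build_app(engine) -> FastAPI:  # `engine` is assumed to be an R2REngine built elsewhere
    app = FastAPI()
    # build_router() (inherited from BaseRouter) instantiates each router class
    # and returns its APIRouter; the /v1 prefix matches R2RClient's default.
    app.include_router(IngestionRouter.build_router(engine), prefix="/v1")
    app.include_router(ManagementRouter.build_router(engine), prefix="/v1")
    app.include_router(RetrievalRouter.build_router(engine), prefix="/v1")
    return app
```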