Diffstat (limited to 'R2R/r2r/main')
25 files changed, 3641 insertions, 0 deletions
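The change set below introduces the R2R `main` package: typed request models (api/requests.py), a synchronous HTTP client (api/client.py), FastAPI routers (api/routes/), the engine and service glue (engine.py), and an assembly layer (assembly/builder.py, assembly/factory.py). For orientation, here is a minimal sketch of the new client API, assuming a server built from app_entry.py is listening on port 8000; the file path below is illustrative and not part of the diff:

    from r2r.main import R2RClient

    client = R2RClient(base_url="http://localhost:8000")  # endpoints are mounted under /v1
    print(client.health())  # GET /v1/health -> {"response": "ok"}
    client.ingest_files(["./data/example.txt"])  # multipart POST to /v1/ingest_files
    print(client.search("What does the corpus cover?", search_limit=5))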
diff --git a/R2R/r2r/main/__init__.py b/R2R/r2r/main/__init__.py new file mode 100755 index 00000000..55a828d6 --- /dev/null +++ b/R2R/r2r/main/__init__.py @@ -0,0 +1,54 @@ +from .abstractions import R2RPipelines, R2RProviders +from .api.client import R2RClient +from .api.requests import ( + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsOverviewRequest, + R2REvalRequest, + R2RIngestFilesRequest, + R2RRAGRequest, + R2RSearchRequest, + R2RUpdateFilesRequest, + R2RUpdatePromptRequest, + R2RUsersOverviewRequest, +) +from .app import R2RApp +from .assembly.builder import R2RBuilder +from .assembly.config import R2RConfig +from .assembly.factory import ( + R2RPipeFactory, + R2RPipelineFactory, + R2RProviderFactory, +) +from .assembly.factory_extensions import R2RPipeFactoryWithMultiSearch +from .engine import R2REngine +from .execution import R2RExecutionWrapper +from .r2r import R2R + +__all__ = [ + "R2R", + "R2RPipelines", + "R2RProviders", + "R2RUpdatePromptRequest", + "R2RIngestFilesRequest", + "R2RUpdateFilesRequest", + "R2RSearchRequest", + "R2RRAGRequest", + "R2REvalRequest", + "R2RDeleteRequest", + "R2RAnalyticsRequest", + "R2RUsersOverviewRequest", + "R2RDocumentsOverviewRequest", + "R2RDocumentChunksRequest", + "R2REngine", + "R2RExecutionWrapper", + "R2RConfig", + "R2RClient", + "R2RPipeFactory", + "R2RPipelineFactory", + "R2RProviderFactory", + "R2RPipeFactoryWithMultiSearch", + "R2RBuilder", + "R2RApp", +] diff --git a/R2R/r2r/main/abstractions.py b/R2R/r2r/main/abstractions.py new file mode 100755 index 00000000..3622b22d --- /dev/null +++ b/R2R/r2r/main/abstractions.py @@ -0,0 +1,58 @@ +from typing import Optional + +from pydantic import BaseModel + +from r2r.base import ( + AsyncPipe, + EmbeddingProvider, + EvalProvider, + KGProvider, + LLMProvider, + PromptProvider, + VectorDBProvider, +) +from r2r.pipelines import ( + EvalPipeline, + IngestionPipeline, + RAGPipeline, + SearchPipeline, +) + + +class R2RProviders(BaseModel): + vector_db: Optional[VectorDBProvider] + embedding: Optional[EmbeddingProvider] + llm: Optional[LLMProvider] + prompt: Optional[PromptProvider] + eval: Optional[EvalProvider] + kg: Optional[KGProvider] + + class Config: + arbitrary_types_allowed = True + + +class R2RPipes(BaseModel): + parsing_pipe: Optional[AsyncPipe] + embedding_pipe: Optional[AsyncPipe] + vector_storage_pipe: Optional[AsyncPipe] + vector_search_pipe: Optional[AsyncPipe] + rag_pipe: Optional[AsyncPipe] + streaming_rag_pipe: Optional[AsyncPipe] + eval_pipe: Optional[AsyncPipe] + kg_pipe: Optional[AsyncPipe] + kg_storage_pipe: Optional[AsyncPipe] + kg_agent_search_pipe: Optional[AsyncPipe] + + class Config: + arbitrary_types_allowed = True + + +class R2RPipelines(BaseModel): + eval_pipeline: EvalPipeline + ingestion_pipeline: IngestionPipeline + search_pipeline: SearchPipeline + rag_pipeline: RAGPipeline + streaming_rag_pipeline: RAGPipeline + + class Config: + arbitrary_types_allowed = True diff --git a/R2R/r2r/main/api/__init__.py b/R2R/r2r/main/api/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/main/api/__init__.py diff --git a/R2R/r2r/main/api/client.py b/R2R/r2r/main/api/client.py new file mode 100755 index 00000000..b0f5b966 --- /dev/null +++ b/R2R/r2r/main/api/client.py @@ -0,0 +1,377 @@ +import asyncio +import functools +import json +import os +import threading +import time +import uuid +from contextlib import ExitStack +from typing import Any, AsyncGenerator, Generator, Optional, Union + +import fire +import 
httpx +import nest_asyncio +import requests + +from .requests import ( + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsOverviewRequest, + R2RIngestFilesRequest, + R2RLogsRequest, + R2RPrintRelationshipsRequest, + R2RRAGRequest, + R2RSearchRequest, + R2RUpdateFilesRequest, + R2RUpdatePromptRequest, + R2RUsersOverviewRequest, +) + +nest_asyncio.apply() + + +class R2RHTTPError(Exception): + def __init__(self, status_code, error_type, message): + self.status_code = status_code + self.error_type = error_type + self.message = message + super().__init__(f"[{status_code}] {error_type}: {message}") + + +def handle_request_error(response): + if response.status_code >= 400: + try: + error_content = response.json() + if isinstance(error_content, dict) and "detail" in error_content: + detail = error_content["detail"] + if isinstance(detail, dict): + message = detail.get("message", str(response.text)) + error_type = detail.get("error_type", "UnknownError") + else: + message = str(detail) + error_type = "HTTPException" + else: + message = str(error_content) + error_type = "UnknownError" + except json.JSONDecodeError: + message = response.text + error_type = "UnknownError" + + raise R2RHTTPError( + status_code=response.status_code, + error_type=error_type, + message=message, + ) + + +def monitor_request(func): + @functools.wraps(func) + def wrapper(*args, monitor=False, **kwargs): + if not monitor: + return func(*args, **kwargs) + + result = None + exception = None + + def run_func(): + nonlocal result, exception + try: + result = func(*args, **kwargs) + except Exception as e: + exception = e + + thread = threading.Thread(target=run_func) + thread.start() + + dots = [".", "..", "..."] + i = 0 + while thread.is_alive(): + print(f"\rRequesting{dots[i % 3]}", end="", flush=True) + i += 1 + time.sleep(0.5) + + thread.join() + + print("\r", end="", flush=True) + + if exception: + raise exception + return result + + return wrapper + + +class R2RClient: + def __init__(self, base_url: str, prefix: str = "/v1"): + self.base_url = base_url + self.prefix = prefix + + def _make_request(self, method, endpoint, **kwargs): + url = f"{self.base_url}{self.prefix}/{endpoint}" + response = requests.request(method, url, **kwargs) + handle_request_error(response) + return response.json() + + def health(self) -> dict: + return self._make_request("GET", "health") + + def update_prompt( + self, + name: str = "default_system", + template: Optional[str] = None, + input_types: Optional[dict] = None, + ) -> dict: + request = R2RUpdatePromptRequest( + name=name, template=template, input_types=input_types + ) + return self._make_request( + "POST", "update_prompt", json=json.loads(request.json()) + ) + + @monitor_request + def ingest_files( + self, + file_paths: list[str], + metadatas: Optional[list[dict]] = None, + document_ids: Optional[list[Union[uuid.UUID, str]]] = None, + versions: Optional[list[str]] = None, + ) -> dict: + all_file_paths = [] + + for path in file_paths: + if os.path.isdir(path): + for root, _, files in os.walk(path): + all_file_paths.extend( + os.path.join(root, file) for file in files + ) + else: + all_file_paths.append(path) + + files_to_upload = [ + ( + "files", + ( + os.path.basename(file), + open(file, "rb"), + "application/octet-stream", + ), + ) + for file in all_file_paths + ] + request = R2RIngestFilesRequest( + metadatas=metadatas, + document_ids=( + [str(ele) for ele in document_ids] if document_ids else None + ), + versions=versions, + ) + try: + return 
self._make_request( + "POST", + "ingest_files", + data={ + k: json.dumps(v) + for k, v in json.loads(request.json()).items() + }, + files=files_to_upload, + ) + finally: + for _, file_tuple in files_to_upload: + file_tuple[1].close() + + @monitor_request + def update_files( + self, + file_paths: list[str], + document_ids: list[str], + metadatas: Optional[list[dict]] = None, + ) -> dict: + request = R2RUpdateFilesRequest( + metadatas=metadatas, + document_ids=document_ids, + ) + with ExitStack() as stack: + return self._make_request( + "POST", + "update_files", + data={ + k: json.dumps(v) + for k, v in json.loads(request.json()).items() + }, + files=[ + ( + "files", + ( + path.split("/")[-1], + stack.enter_context(open(path, "rb")), + "application/octet-stream", + ), + ) + for path in file_paths + ], + ) + + def search( + self, + query: str, + use_vector_search: bool = True, + search_filters: Optional[dict[str, Any]] = {}, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg_search: bool = False, + kg_agent_generation_config: Optional[dict] = None, + ) -> dict: + request = R2RSearchRequest( + query=query, + vector_search_settings={ + "use_vector_search": use_vector_search, + "search_filters": search_filters or {}, + "search_limit": search_limit, + "do_hybrid_search": do_hybrid_search, + }, + kg_search_settings={ + "use_kg_search": use_kg_search, + "agent_generation_config": kg_agent_generation_config, + }, + ) + return self._make_request( + "POST", "search", json=json.loads(request.json()) + ) + + def rag( + self, + query: str, + use_vector_search: bool = True, + search_filters: Optional[dict[str, Any]] = {}, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg_search: bool = False, + kg_agent_generation_config: Optional[dict] = None, + rag_generation_config: Optional[dict] = None, + ) -> dict: + request = R2RRAGRequest( + query=query, + vector_search_settings={ + "use_vector_search": use_vector_search, + "search_filters": search_filters or {}, + "search_limit": search_limit, + "do_hybrid_search": do_hybrid_search, + }, + kg_search_settings={ + "use_kg_search": use_kg_search, + "agent_generation_config": kg_agent_generation_config, + }, + rag_generation_config=rag_generation_config, + ) + + if rag_generation_config and rag_generation_config.get( + "stream", False + ): + return self._stream_rag_sync(request) + else: + return self._make_request( + "POST", "rag", json=json.loads(request.json()) + ) + + async def _stream_rag( + self, rag_request: R2RRAGRequest + ) -> AsyncGenerator[str, None]: + url = f"{self.base_url}{self.prefix}/rag" + async with httpx.AsyncClient() as client: + async with client.stream( + "POST", url, json=json.loads(rag_request.json()) + ) as response: + handle_request_error(response) + async for chunk in response.aiter_text(): + yield chunk + + def _stream_rag_sync( + self, rag_request: R2RRAGRequest + ) -> Generator[str, None, None]: + async def run_async_generator(): + async for chunk in self._stream_rag(rag_request): + yield chunk + + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + async_gen = run_async_generator() + + try: + while True: + chunk = loop.run_until_complete(async_gen.__anext__()) + yield chunk + except StopAsyncIteration: + pass + finally: + loop.close() + + def delete( + self, keys: list[str], values: list[Union[bool, int, str]] + ) -> dict: + request = R2RDeleteRequest(keys=keys, values=values) + return self._make_request( + "DELETE", "delete", json=json.loads(request.json()) + ) + + def logs(self, 
log_type_filter: Optional[str] = None) -> dict: + request = R2RLogsRequest(log_type_filter=log_type_filter) + return self._make_request( + "GET", "logs", json=json.loads(request.json()) + ) + + def app_settings(self) -> dict: + return self._make_request("GET", "app_settings") + + def analytics(self, filter_criteria: dict, analysis_types: dict) -> dict: + request = R2RAnalyticsRequest( + filter_criteria=filter_criteria, analysis_types=analysis_types + ) + return self._make_request( + "GET", "analytics", json=json.loads(request.json()) + ) + + def users_overview( + self, user_ids: Optional[list[uuid.UUID]] = None + ) -> dict: + request = R2RUsersOverviewRequest(user_ids=user_ids) + return self._make_request( + "GET", "users_overview", json=json.loads(request.json()) + ) + + def documents_overview( + self, + document_ids: Optional[list[str]] = None, + user_ids: Optional[list[str]] = None, + ) -> dict: + request = R2RDocumentsOverviewRequest( + document_ids=( + [uuid.UUID(did) for did in document_ids] + if document_ids + else None + ), + user_ids=( + [uuid.UUID(uid) for uid in user_ids] if user_ids else None + ), + ) + return self._make_request( + "GET", "documents_overview", json=json.loads(request.json()) + ) + + def document_chunks(self, document_id: str) -> dict: + request = R2RDocumentChunksRequest(document_id=document_id) + return self._make_request( + "GET", "document_chunks", json=json.loads(request.json()) + ) + + def inspect_knowledge_graph(self, limit: int = 100) -> dict: + request = R2RPrintRelationshipsRequest(limit=limit) + return self._make_request( + "POST", "inspect_knowledge_graph", json=json.loads(request.json()) + ) + + +if __name__ == "__main__": + client = R2RClient(base_url="http://localhost:8000") + fire.Fire(client) diff --git a/R2R/r2r/main/api/requests.py b/R2R/r2r/main/api/requests.py new file mode 100755 index 00000000..5c63ab82 --- /dev/null +++ b/R2R/r2r/main/api/requests.py @@ -0,0 +1,79 @@ +import uuid +from typing import Optional, Union + +from pydantic import BaseModel + +from r2r.base import AnalysisTypes, FilterCriteria + + +class R2RUpdatePromptRequest(BaseModel): + name: str + template: Optional[str] = None + input_types: Optional[dict[str, str]] = {} + + +class R2RIngestFilesRequest(BaseModel): + document_ids: Optional[list[uuid.UUID]] = None + metadatas: Optional[list[dict]] = None + versions: Optional[list[str]] = None + + +class R2RUpdateFilesRequest(BaseModel): + metadatas: Optional[list[dict]] = None + document_ids: Optional[list[uuid.UUID]] = None + + +class R2RSearchRequest(BaseModel): + query: str + vector_search_settings: Optional[dict] = None + kg_search_settings: Optional[dict] = None + + +class R2RRAGRequest(BaseModel): + query: str + vector_search_settings: Optional[dict] = None + kg_search_settings: Optional[dict] = None + rag_generation_config: Optional[dict] = None + + +class R2REvalRequest(BaseModel): + query: str + context: str + completion: str + + +class R2RDeleteRequest(BaseModel): + keys: list[str] + values: list[Union[bool, int, str]] + + +class R2RAnalyticsRequest(BaseModel): + filter_criteria: FilterCriteria + analysis_types: AnalysisTypes + + +class R2RUsersOverviewRequest(BaseModel): + user_ids: Optional[list[uuid.UUID]] + + +class R2RDocumentsOverviewRequest(BaseModel): + document_ids: Optional[list[uuid.UUID]] + user_ids: Optional[list[uuid.UUID]] + + +class R2RDocumentChunksRequest(BaseModel): + document_id: uuid.UUID + + +class R2RLogsRequest(BaseModel): + log_type_filter: Optional[str] = None + max_runs_requested: int = 
100 + + +class R2RPrintRelationshipsRequest(BaseModel): + limit: int = 100 + + +class R2RExtractionRequest(BaseModel): + entity_types: list[str] + relations: list[str] diff --git a/R2R/r2r/main/api/routes/__init__.py b/R2R/r2r/main/api/routes/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/main/api/routes/__init__.py diff --git a/R2R/r2r/main/api/routes/base_router.py b/R2R/r2r/main/api/routes/base_router.py new file mode 100755 index 00000000..d06a9935 --- /dev/null +++ b/R2R/r2r/main/api/routes/base_router.py @@ -0,0 +1,75 @@ +import functools +import logging + +from fastapi import APIRouter, HTTPException +from fastapi.responses import StreamingResponse + +from r2r.base import R2RException, manage_run + +logger = logging.getLogger(__name__) + + +class BaseRouter: + def __init__(self, engine): + self.engine = engine + self.router = APIRouter() + + def base_endpoint(self, func): + @functools.wraps(func) + async def wrapper(*args, **kwargs): + async with manage_run( + self.engine.run_manager, func.__name__ + ) as run_id: + try: + results = await func(*args, **kwargs) + if isinstance(results, StreamingResponse): + return results + + return {"results": results} + except R2RException as re: + raise HTTPException( + status_code=re.status_code, + detail={ + "message": re.message, + "error_type": type(re).__name__, + }, + ) + except Exception as e: + # Get the pipeline name based on the function name + pipeline_name = f"{func.__name__.split('_')[0]}_pipeline" + + # Safely get the pipeline object and its type + pipeline = getattr( + self.engine.pipelines, pipeline_name, None + ) + pipeline_type = getattr( + pipeline, "pipeline_type", "unknown" + ) + + await self.engine.logging_connection.log( + log_id=run_id, + key="pipeline_type", + value=pipeline_type, + is_info_log=True, + ) + await self.engine.logging_connection.log( + log_id=run_id, + key="error", + value=str(e), + is_info_log=False, + ) + logger.error(f"{func.__name__}() - \n\n{str(e)})") + raise HTTPException( + status_code=500, + detail={ + "message": f"An error occurred during {func.__name__}", + "error": str(e), + "error_type": type(e).__name__, + }, + ) from e + + return wrapper + + @classmethod + def build_router(cls, engine): + return cls(engine).router diff --git a/R2R/r2r/main/api/routes/ingestion.py b/R2R/r2r/main/api/routes/ingestion.py new file mode 100755 index 00000000..be583602 --- /dev/null +++ b/R2R/r2r/main/api/routes/ingestion.py @@ -0,0 +1,42 @@ +from fastapi import Depends, File, UploadFile + +from ...engine import R2REngine +from ...services.ingestion_service import IngestionService +from ..requests import R2RIngestFilesRequest, R2RUpdateFilesRequest +from .base_router import BaseRouter + + +class IngestionRouter(BaseRouter): + def __init__(self, engine: R2REngine): + super().__init__(engine) + self.setup_routes() + + def setup_routes(self): + @self.router.post("/ingest_files") + @self.base_endpoint + async def ingest_files_app( + files: list[UploadFile] = File(...), + request: R2RIngestFilesRequest = Depends( + IngestionService.parse_ingest_files_form_data + ), + ): + return await self.engine.aingest_files( + files=files, + metadatas=request.metadatas, + document_ids=request.document_ids, + versions=request.versions, + ) + + @self.router.post("/update_files") + @self.base_endpoint + async def update_files_app( + files: list[UploadFile] = File(...), + request: R2RUpdateFilesRequest = Depends( + IngestionService.parse_update_files_form_data + ), + ): + return await 
self.engine.aupdate_files( + files=files, + metadatas=request.metadatas, + document_ids=request.document_ids, + ) diff --git a/R2R/r2r/main/api/routes/management.py b/R2R/r2r/main/api/routes/management.py new file mode 100755 index 00000000..921fb534 --- /dev/null +++ b/R2R/r2r/main/api/routes/management.py @@ -0,0 +1,101 @@ +from ...engine import R2REngine +from ..requests import ( + R2RAnalyticsRequest, + R2RDeleteRequest, + R2RDocumentChunksRequest, + R2RDocumentsOverviewRequest, + R2RLogsRequest, + R2RPrintRelationshipsRequest, + R2RUpdatePromptRequest, + R2RUsersOverviewRequest, +) +from .base_router import BaseRouter + + +class ManagementRouter(BaseRouter): + def __init__(self, engine: R2REngine): + super().__init__(engine) + self.setup_routes() + + def setup_routes(self): + @self.router.get("/health") + async def health_check(): + return {"response": "ok"} + + @self.router.post("/update_prompt") + @self.base_endpoint + async def update_prompt_app(request: R2RUpdatePromptRequest): + return await self.engine.aupdate_prompt( + request.name, request.template, request.input_types + ) + + @self.router.post("/logs") + @self.router.get("/logs") + @self.base_endpoint + async def get_logs_app(request: R2RLogsRequest): + return await self.engine.alogs( + log_type_filter=request.log_type_filter, + max_runs_requested=request.max_runs_requested, + ) + + @self.router.post("/analytics") + @self.router.get("/analytics") + @self.base_endpoint + async def get_analytics_app(request: R2RAnalyticsRequest): + return await self.engine.aanalytics( + filter_criteria=request.filter_criteria, + analysis_types=request.analysis_types, + ) + + @self.router.post("/users_overview") + @self.router.get("/users_overview") + @self.base_endpoint + async def get_users_overview_app(request: R2RUsersOverviewRequest): + return await self.engine.ausers_overview(user_ids=request.user_ids) + + @self.router.delete("/delete") + @self.base_endpoint + async def delete_app(request: R2RDeleteRequest): + return await self.engine.adelete( + keys=request.keys, values=request.values + ) + + @self.router.post("/documents_overview") + @self.router.get("/documents_overview") + @self.base_endpoint + async def get_documents_overview_app( + request: R2RDocumentsOverviewRequest, + ): + return await self.engine.adocuments_overview( + document_ids=request.document_ids, user_ids=request.user_ids + ) + + @self.router.post("/document_chunks") + @self.router.get("/document_chunks") + @self.base_endpoint + async def get_document_chunks_app(request: R2RDocumentChunksRequest): + return await self.engine.adocument_chunks(request.document_id) + + @self.router.post("/inspect_knowledge_graph") + @self.router.get("/inspect_knowledge_graph") + @self.base_endpoint + async def inspect_knowledge_graph( + request: R2RPrintRelationshipsRequest, + ): + return await self.engine.inspect_knowledge_graph( + limit=request.limit + ) + + @self.router.get("/app_settings") + @self.base_endpoint + async def get_app_settings_app(): + return await self.engine.aapp_settings() + + @self.router.get("/openapi_spec") + @self.base_endpoint + def get_openapi_spec_app(): + return self.engine.openapi_spec() + + +def create_management_router(engine: R2REngine): + return ManagementRouter(engine).router diff --git a/R2R/r2r/main/api/routes/retrieval.py b/R2R/r2r/main/api/routes/retrieval.py new file mode 100755 index 00000000..b2d352aa --- /dev/null +++ b/R2R/r2r/main/api/routes/retrieval.py @@ -0,0 +1,91 @@ +from fastapi.responses import StreamingResponse + +from r2r.base import 
GenerationConfig, KGSearchSettings, VectorSearchSettings + +from ...engine import R2REngine +from ..requests import R2REvalRequest, R2RRAGRequest, R2RSearchRequest +from .base_router import BaseRouter + + +class RetrievalRouter(BaseRouter): + def __init__(self, engine: R2REngine): + super().__init__(engine) + self.setup_routes() + + def setup_routes(self): + @self.router.post("/search") + @self.base_endpoint + async def search_app(request: R2RSearchRequest): + if "agent_generation_config" in request.kg_search_settings: + request.kg_search_settings["agent_generation_config"] = ( + GenerationConfig( + **request.kg_search_settings["agent_generation_config"] + or {} + ) + ) + + results = await self.engine.asearch( + query=request.query, + vector_search_settings=VectorSearchSettings( + **(request.vector_search_settings or {}) + ), + kg_search_settings=KGSearchSettings( + **(request.kg_search_settings or {}) + ), + ) + return results + + @self.router.post("/rag") + @self.base_endpoint + async def rag_app(request: R2RRAGRequest): + if "agent_generation_config" in request.kg_search_settings: + request.kg_search_settings["agent_generation_config"] = ( + GenerationConfig( + **( + request.kg_search_settings[ + "agent_generation_config" + ] + or {} + ) + ) + ) + response = await self.engine.arag( + query=request.query, + vector_search_settings=VectorSearchSettings( + **(request.vector_search_settings or {}) + ), + kg_search_settings=KGSearchSettings( + **(request.kg_search_settings or {}) + ), + rag_generation_config=GenerationConfig( + **(request.rag_generation_config or {}) + ), + ) + if ( + request.rag_generation_config + and request.rag_generation_config.get("stream", False) + ): + + async def stream_generator(): + async for chunk in response: + yield chunk + + return StreamingResponse( + stream_generator(), media_type="application/json" + ) + else: + return response + + @self.router.post("/evaluate") + @self.base_endpoint + async def evaluate_app(request: R2REvalRequest): + results = await self.engine.aevaluate( + query=request.query, + context=request.context, + completion=request.completion, + ) + return results + + +def create_retrieval_router(engine: R2REngine): + return RetrievalRouter(engine).router diff --git a/R2R/r2r/main/app.py b/R2R/r2r/main/app.py new file mode 100755 index 00000000..981445e4 --- /dev/null +++ b/R2R/r2r/main/app.py @@ -0,0 +1,53 @@ +from fastapi import FastAPI + +from .engine import R2REngine + + +class R2RApp: + def __init__(self, engine: R2REngine): + self.engine = engine + self._setup_routes() + self._apply_cors() + + async def openapi_spec(self, *args, **kwargs): + from fastapi.openapi.utils import get_openapi + + return get_openapi( + title="R2R Application API", + version="1.0.0", + routes=self.app.routes, + ) + + def _setup_routes(self): + from .api.routes import ingestion, management, retrieval + + self.app = FastAPI() + + # Create routers with the engine + ingestion_router = ingestion.IngestionRouter.build_router(self.engine) + management_router = management.ManagementRouter.build_router( + self.engine + ) + retrieval_router = retrieval.RetrievalRouter.build_router(self.engine) + + # Include routers in the app + self.app.include_router(ingestion_router, prefix="/v1") + self.app.include_router(management_router, prefix="/v1") + self.app.include_router(retrieval_router, prefix="/v1") + + def _apply_cors(self): + from fastapi.middleware.cors import CORSMiddleware + + origins = ["*", "http://localhost:3000", "http://localhost:8000"] + self.app.add_middleware( + 
CORSMiddleware, + allow_origins=origins, + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + ) + + def serve(self, host: str = "0.0.0.0", port: int = 8000): + import uvicorn + + uvicorn.run(self.app, host=host, port=port) diff --git a/R2R/r2r/main/app_entry.py b/R2R/r2r/main/app_entry.py new file mode 100755 index 00000000..29b705d7 --- /dev/null +++ b/R2R/r2r/main/app_entry.py @@ -0,0 +1,84 @@ +import logging +import os +from enum import Enum +from typing import Optional + +from fastapi import FastAPI + +from r2r import R2RBuilder, R2RConfig +from r2r.main.execution import R2RExecutionWrapper + +logger = logging.getLogger(__name__) +current_file_path = os.path.dirname(__file__) +configs_path = os.path.join(current_file_path, "..", "..", "..") + + +class PipelineType(Enum): + QNA = "qna" + WEB = "web" + HYDE = "hyde" + + +def r2r_app( + config_name: Optional[str] = "default", + config_path: Optional[str] = None, + client_mode: bool = False, + base_url: Optional[str] = None, + pipeline_type: PipelineType = PipelineType.QNA, +) -> FastAPI: + if pipeline_type != PipelineType.QNA: + raise ValueError("Only QNA pipeline is supported in quickstart.") + if config_path and config_name: + raise ValueError("Cannot specify both config and config_name") + + if config_path: + config = R2RConfig.from_json(config_path) + else: + config_name = os.getenv("CONFIG_NAME") or config_name + if config_name not in R2RBuilder.CONFIG_OPTIONS: + raise ValueError(f"Invalid config name: {config_name}") + config = R2RConfig.from_json(R2RBuilder.CONFIG_OPTIONS[config_name]) + + if ( + config.embedding.provider == "openai" + and "OPENAI_API_KEY" not in os.environ + ): + raise ValueError( + "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider." 
+ ) + + wrapper = R2RExecutionWrapper( + config_name=config_name, + config_path=config_path, + client_mode=client_mode, + base_url=base_url, + ) + + return wrapper.get_app() + + +logging.basicConfig(level=logging.INFO) + +config_name = os.getenv("CONFIG_NAME", None) +config_path = os.getenv("CONFIG_PATH", None) +if not config_path and not config_name: + config_name = "default" +client_mode = os.getenv("CLIENT_MODE", "false").lower() == "true" +base_url = os.getenv("BASE_URL") +host = os.getenv("HOST", "0.0.0.0") +port = int(os.getenv("PORT", "8000")) +pipeline_type = os.getenv("PIPELINE_TYPE", "qna") + +logger.info(f"Environment CONFIG_NAME: {config_name}") +logger.info(f"Environment CONFIG_PATH: {config_path}") +logger.info(f"Environment CLIENT_MODE: {client_mode}") +logger.info(f"Environment BASE_URL: {base_url}") +logger.info(f"Environment PIPELINE_TYPE: {pipeline_type}") + +app = r2r_app( + config_name=config_name, + config_path=config_path, + client_mode=client_mode, + base_url=base_url, + pipeline_type=PipelineType(pipeline_type), +) diff --git a/R2R/r2r/main/assembly/__init__.py b/R2R/r2r/main/assembly/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/main/assembly/__init__.py diff --git a/R2R/r2r/main/assembly/builder.py b/R2R/r2r/main/assembly/builder.py new file mode 100755 index 00000000..863fc6d0 --- /dev/null +++ b/R2R/r2r/main/assembly/builder.py @@ -0,0 +1,207 @@ +import os +from typing import Optional, Type + +from r2r.base import ( + AsyncPipe, + EmbeddingProvider, + EvalProvider, + LLMProvider, + PromptProvider, + VectorDBProvider, +) +from r2r.pipelines import ( + EvalPipeline, + IngestionPipeline, + RAGPipeline, + SearchPipeline, +) + +from ..app import R2RApp +from ..engine import R2REngine +from ..r2r import R2R +from .config import R2RConfig +from .factory import R2RPipeFactory, R2RPipelineFactory, R2RProviderFactory + + +class R2RBuilder: + current_file_path = os.path.dirname(__file__) + config_root = os.path.join( + current_file_path, "..", "..", "examples", "configs" + ) + CONFIG_OPTIONS = { + "default": None, + "local_ollama": os.path.join(config_root, "local_ollama.json"), + "local_ollama_rerank": os.path.join( + config_root, "local_ollama_rerank.json" + ), + "neo4j_kg": os.path.join(config_root, "neo4j_kg.json"), + "local_neo4j_kg": os.path.join(config_root, "local_neo4j_kg.json"), + "postgres_logging": os.path.join(config_root, "postgres_logging.json"), + } + + @staticmethod + def _get_config(config_name): + if config_name is None: + return R2RConfig.from_json() + if config_name in R2RBuilder.CONFIG_OPTIONS: + return R2RConfig.from_json(R2RBuilder.CONFIG_OPTIONS[config_name]) + raise ValueError(f"Invalid config name: {config_name}") + + def __init__( + self, + config: Optional[R2RConfig] = None, + from_config: Optional[str] = None, + ): + if config and from_config: + raise ValueError("Cannot specify both config and config_name") + self.config = config or R2RBuilder._get_config(from_config) + self.r2r_app_override: Optional[Type[R2REngine]] = None + self.provider_factory_override: Optional[Type[R2RProviderFactory]] = ( + None + ) + self.pipe_factory_override: Optional[R2RPipeFactory] = None + self.pipeline_factory_override: Optional[R2RPipelineFactory] = None + self.vector_db_provider_override: Optional[VectorDBProvider] = None + self.embedding_provider_override: Optional[EmbeddingProvider] = None + self.eval_provider_override: Optional[EvalProvider] = None + self.llm_provider_override: Optional[LLMProvider] = None + 
self.prompt_provider_override: Optional[PromptProvider] = None + self.parsing_pipe_override: Optional[AsyncPipe] = None + self.embedding_pipe_override: Optional[AsyncPipe] = None + self.vector_storage_pipe_override: Optional[AsyncPipe] = None + self.vector_search_pipe_override: Optional[AsyncPipe] = None + self.rag_pipe_override: Optional[AsyncPipe] = None + self.streaming_rag_pipe_override: Optional[AsyncPipe] = None + self.eval_pipe_override: Optional[AsyncPipe] = None + self.ingestion_pipeline: Optional[IngestionPipeline] = None + self.search_pipeline: Optional[SearchPipeline] = None + self.rag_pipeline: Optional[RAGPipeline] = None + self.streaming_rag_pipeline: Optional[RAGPipeline] = None + self.eval_pipeline: Optional[EvalPipeline] = None + + def with_app(self, app: Type[R2REngine]): + self.r2r_app_override = app + return self + + def with_provider_factory(self, factory: Type[R2RProviderFactory]): + self.provider_factory_override = factory + return self + + def with_pipe_factory(self, factory: R2RPipeFactory): + self.pipe_factory_override = factory + return self + + def with_pipeline_factory(self, factory: R2RPipelineFactory): + self.pipeline_factory_override = factory + return self + + def with_vector_db_provider(self, provider: VectorDBProvider): + self.vector_db_provider_override = provider + return self + + def with_embedding_provider(self, provider: EmbeddingProvider): + self.embedding_provider_override = provider + return self + + def with_eval_provider(self, provider: EvalProvider): + self.eval_provider_override = provider + return self + + def with_llm_provider(self, provider: LLMProvider): + self.llm_provider_override = provider + return self + + def with_prompt_provider(self, provider: PromptProvider): + self.prompt_provider_override = provider + return self + + def with_parsing_pipe(self, pipe: AsyncPipe): + self.parsing_pipe_override = pipe + return self + + def with_embedding_pipe(self, pipe: AsyncPipe): + self.embedding_pipe_override = pipe + return self + + def with_vector_storage_pipe(self, pipe: AsyncPipe): + self.vector_storage_pipe_override = pipe + return self + + def with_vector_search_pipe(self, pipe: AsyncPipe): + self.vector_search_pipe_override = pipe + return self + + def with_rag_pipe(self, pipe: AsyncPipe): + self.rag_pipe_override = pipe + return self + + def with_streaming_rag_pipe(self, pipe: AsyncPipe): + self.streaming_rag_pipe_override = pipe + return self + + def with_eval_pipe(self, pipe: AsyncPipe): + self.eval_pipe_override = pipe + return self + + def with_ingestion_pipeline(self, pipeline: IngestionPipeline): + self.ingestion_pipeline = pipeline + return self + + def with_vector_search_pipeline(self, pipeline: SearchPipeline): + self.search_pipeline = pipeline + return self + + def with_rag_pipeline(self, pipeline: RAGPipeline): + self.rag_pipeline = pipeline + return self + + def with_streaming_rag_pipeline(self, pipeline: RAGPipeline): + self.streaming_rag_pipeline = pipeline + return self + + def with_eval_pipeline(self, pipeline: EvalPipeline): + self.eval_pipeline = pipeline + return self + + def build(self, *args, **kwargs) -> R2R: + provider_factory = self.provider_factory_override or R2RProviderFactory + pipe_factory = self.pipe_factory_override or R2RPipeFactory + pipeline_factory = self.pipeline_factory_override or R2RPipelineFactory + + providers = provider_factory(self.config).create_providers( + vector_db_provider_override=self.vector_db_provider_override, + embedding_provider_override=self.embedding_provider_override, + 
eval_provider_override=self.eval_provider_override, + llm_provider_override=self.llm_provider_override, + prompt_provider_override=self.prompt_provider_override, + *args, + **kwargs, + ) + + pipes = pipe_factory(self.config, providers).create_pipes( + parsing_pipe_override=self.parsing_pipe_override, + embedding_pipe_override=self.embedding_pipe_override, + vector_storage_pipe_override=self.vector_storage_pipe_override, + vector_search_pipe_override=self.vector_search_pipe_override, + rag_pipe_override=self.rag_pipe_override, + streaming_rag_pipe_override=self.streaming_rag_pipe_override, + eval_pipe_override=self.eval_pipe_override, + *args, + **kwargs, + ) + + pipelines = pipeline_factory(self.config, pipes).create_pipelines( + ingestion_pipeline=self.ingestion_pipeline, + search_pipeline=self.search_pipeline, + rag_pipeline=self.rag_pipeline, + streaming_rag_pipeline=self.streaming_rag_pipeline, + eval_pipeline=self.eval_pipeline, + *args, + **kwargs, + ) + + engine = (self.r2r_app_override or R2REngine)( + self.config, providers, pipelines + ) + r2r_app = R2RApp(engine) + return R2R(engine=engine, app=r2r_app) diff --git a/R2R/r2r/main/assembly/config.py b/R2R/r2r/main/assembly/config.py new file mode 100755 index 00000000..d52c4561 --- /dev/null +++ b/R2R/r2r/main/assembly/config.py @@ -0,0 +1,167 @@ +import json +import logging +import os +from enum import Enum +from typing import Any + +from ...base.abstractions.document import DocumentType +from ...base.abstractions.llm import GenerationConfig +from ...base.logging.kv_logger import LoggingConfig +from ...base.providers.embedding_provider import EmbeddingConfig +from ...base.providers.eval_provider import EvalConfig +from ...base.providers.kg_provider import KGConfig +from ...base.providers.llm_provider import LLMConfig +from ...base.providers.prompt_provider import PromptConfig +from ...base.providers.vector_db_provider import ProviderConfig, VectorDBConfig + +logger = logging.getLogger(__name__) + + +class R2RConfig: + REQUIRED_KEYS: dict[str, list] = { + "app": ["max_file_size_in_mb"], + "embedding": [ + "provider", + "base_model", + "base_dimension", + "batch_size", + "text_splitter", + ], + "eval": ["llm"], + "kg": [ + "provider", + "batch_size", + "kg_extraction_config", + "text_splitter", + ], + "ingestion": ["excluded_parsers"], + "completions": ["provider"], + "logging": ["provider", "log_table"], + "prompt": ["provider"], + "vector_database": ["provider"], + } + app: dict[str, Any] + embedding: EmbeddingConfig + completions: LLMConfig + logging: LoggingConfig + prompt: PromptConfig + vector_database: VectorDBConfig + + def __init__(self, config_data: dict[str, Any]): + # Load the default configuration + default_config = self.load_default_config() + + # Override the default configuration with the passed configuration + for key in config_data: + if key in default_config: + default_config[key].update(config_data[key]) + else: + default_config[key] = config_data[key] + + # Validate and set the configuration + for section, keys in R2RConfig.REQUIRED_KEYS.items(): + # Check the keys when provider is set + # TODO - Clean up robust null checks + if "provider" in default_config[section] and ( + default_config[section]["provider"] is not None + and default_config[section]["provider"] != "None" + and default_config[section]["provider"] != "null" + ): + self._validate_config_section(default_config, section, keys) + setattr(self, section, default_config[section]) + + self.app = self.app # for type hinting + self.ingestion = 
self.ingestion # for type hinting + self.ingestion["excluded_parsers"] = [ + DocumentType(k) for k in self.ingestion["excluded_parsers"] + ] + # override GenerationConfig defaults + GenerationConfig.set_default( + **self.completions.get("generation_config", {}) + ) + self.embedding = EmbeddingConfig.create(**self.embedding) + self.kg = KGConfig.create(**self.kg) + eval_llm = self.eval.pop("llm", None) + self.eval = EvalConfig.create( + **self.eval, llm=LLMConfig.create(**eval_llm) if eval_llm else None + ) + self.completions = LLMConfig.create(**self.completions) + self.logging = LoggingConfig.create(**self.logging) + self.prompt = PromptConfig.create(**self.prompt) + self.vector_database = VectorDBConfig.create(**self.vector_database) + + def _validate_config_section( + self, config_data: dict[str, Any], section: str, keys: list + ): + if section not in config_data: + raise ValueError(f"Missing '{section}' section in config") + if not all(key in config_data[section] for key in keys): + raise ValueError(f"Missing required keys in '{section}' config") + + @classmethod + def from_json(cls, config_path: str = None) -> "R2RConfig": + if config_path is None: + # Get the root directory of the project + file_dir = os.path.dirname(os.path.abspath(__file__)) + config_path = os.path.join( + file_dir, "..", "..", "..", "config.json" + ) + + # Load configuration from JSON file + with open(config_path) as f: + config_data = json.load(f) + + return cls(config_data) + + def to_json(self): + config_data = { + section: self._serialize_config(getattr(self, section)) + for section in R2RConfig.REQUIRED_KEYS.keys() + } + return json.dumps(config_data) + + def save_to_redis(self, redis_client: Any, key: str): + redis_client.set(f"R2RConfig:{key}", self.to_json()) + + @classmethod + def load_from_redis(cls, redis_client: Any, key: str) -> "R2RConfig": + config_data = redis_client.get(f"R2RConfig:{key}") + if config_data is None: + raise ValueError( + f"Configuration not found in Redis with key '{key}'" + ) + config_data = json.loads(config_data) + # config_data["ingestion"]["selected_parsers"] = { + # DocumentType(k): v + # for k, v in config_data["ingestion"]["selected_parsers"].items() + # } + return cls(config_data) + + @classmethod + def load_default_config(cls) -> dict: + # Get the root directory of the project + file_dir = os.path.dirname(os.path.abspath(__file__)) + default_config_path = os.path.join( + file_dir, "..", "..", "..", "config.json" + ) + # Load default configuration from JSON file + with open(default_config_path) as f: + return json.load(f) + + @staticmethod + def _serialize_config(config_section: Any) -> dict: + # TODO - Make this approach cleaner + if isinstance(config_section, ProviderConfig): + config_section = config_section.dict() + filtered_result = {} + for k, v in config_section.items(): + if isinstance(k, Enum): + k = k.value + if isinstance(v, dict): + formatted_v = { + k2.value if isinstance(k2, Enum) else k2: v2 + for k2, v2 in v.items() + } + v = formatted_v + filtered_result[k] = v + return filtered_result diff --git a/R2R/r2r/main/assembly/factory.py b/R2R/r2r/main/assembly/factory.py new file mode 100755 index 00000000..4e147337 --- /dev/null +++ b/R2R/r2r/main/assembly/factory.py @@ -0,0 +1,484 @@ +import logging +import os +from typing import Any, Optional + +from r2r.base import ( + AsyncPipe, + EmbeddingConfig, + EmbeddingProvider, + EvalProvider, + KGProvider, + KVLoggingSingleton, + LLMConfig, + LLMProvider, + PromptProvider, + VectorDBConfig, + VectorDBProvider, +) 
+from r2r.pipelines import ( + EvalPipeline, + IngestionPipeline, + RAGPipeline, + SearchPipeline, +) + +from ..abstractions import R2RPipelines, R2RPipes, R2RProviders +from .config import R2RConfig + +logger = logging.getLogger(__name__) + + +class R2RProviderFactory: + def __init__(self, config: R2RConfig): + self.config = config + + def create_vector_db_provider( + self, vector_db_config: VectorDBConfig, *args, **kwargs + ) -> VectorDBProvider: + vector_db_provider: Optional[VectorDBProvider] = None + if vector_db_config.provider == "pgvector": + from r2r.providers.vector_dbs import PGVectorDB + + vector_db_provider = PGVectorDB(vector_db_config) + else: + raise ValueError( + f"Vector database provider {vector_db_config.provider} not supported" + ) + if not vector_db_provider: + raise ValueError("Vector database provider not found") + + if not self.config.embedding.base_dimension: + raise ValueError("Search dimension not found in embedding config") + + vector_db_provider.initialize_collection( + self.config.embedding.base_dimension + ) + return vector_db_provider + + def create_embedding_provider( + self, embedding: EmbeddingConfig, *args, **kwargs + ) -> EmbeddingProvider: + embedding_provider: Optional[EmbeddingProvider] = None + + if embedding.provider == "openai": + if not os.getenv("OPENAI_API_KEY"): + raise ValueError( + "Must set OPENAI_API_KEY in order to initialize OpenAIEmbeddingProvider." + ) + from r2r.providers.embeddings import OpenAIEmbeddingProvider + + embedding_provider = OpenAIEmbeddingProvider(embedding) + elif embedding.provider == "ollama": + from r2r.providers.embeddings import OllamaEmbeddingProvider + + embedding_provider = OllamaEmbeddingProvider(embedding) + + elif embedding.provider == "sentence-transformers": + from r2r.providers.embeddings import ( + SentenceTransformerEmbeddingProvider, + ) + + embedding_provider = SentenceTransformerEmbeddingProvider( + embedding + ) + elif embedding is None: + embedding_provider = None + else: + raise ValueError( + f"Embedding provider {embedding.provider} not supported" + ) + + return embedding_provider + + def create_eval_provider( + self, eval_config, prompt_provider, *args, **kwargs + ) -> Optional[EvalProvider]: + if eval_config.provider == "local": + from r2r.providers.eval import LLMEvalProvider + + llm_provider = self.create_llm_provider(eval_config.llm) + eval_provider = LLMEvalProvider( + eval_config, + llm_provider=llm_provider, + prompt_provider=prompt_provider, + ) + elif eval_config.provider is None: + eval_provider = None + else: + raise ValueError( + f"Eval provider {eval_config.provider} not supported." 
+ ) + + return eval_provider + + def create_llm_provider( + self, llm_config: LLMConfig, *args, **kwargs + ) -> LLMProvider: + llm_provider: Optional[LLMProvider] = None + if llm_config.provider == "openai": + from r2r.providers.llms import OpenAILLM + + llm_provider = OpenAILLM(llm_config) + elif llm_config.provider == "litellm": + from r2r.providers.llms import LiteLLM + + llm_provider = LiteLLM(llm_config) + else: + raise ValueError( + f"Language model provider {llm_config.provider} not supported" + ) + if not llm_provider: + raise ValueError("Language model provider not found") + return llm_provider + + def create_prompt_provider( + self, prompt_config, *args, **kwargs + ) -> PromptProvider: + prompt_provider = None + if prompt_config.provider == "local": + from r2r.prompts import R2RPromptProvider + + prompt_provider = R2RPromptProvider() + else: + raise ValueError( + f"Prompt provider {prompt_config.provider} not supported" + ) + return prompt_provider + + def create_kg_provider(self, kg_config, *args, **kwargs): + if kg_config.provider == "neo4j": + from r2r.providers.kg import Neo4jKGProvider + + return Neo4jKGProvider(kg_config) + elif kg_config.provider is None: + return None + else: + raise ValueError( + f"KG provider {kg_config.provider} not supported." + ) + + def create_providers( + self, + vector_db_provider_override: Optional[VectorDBProvider] = None, + embedding_provider_override: Optional[EmbeddingProvider] = None, + eval_provider_override: Optional[EvalProvider] = None, + llm_provider_override: Optional[LLMProvider] = None, + prompt_provider_override: Optional[PromptProvider] = None, + kg_provider_override: Optional[KGProvider] = None, + *args, + **kwargs, + ) -> R2RProviders: + prompt_provider = ( + prompt_provider_override + or self.create_prompt_provider(self.config.prompt, *args, **kwargs) + ) + return R2RProviders( + vector_db=vector_db_provider_override + or self.create_vector_db_provider( + self.config.vector_database, *args, **kwargs + ), + embedding=embedding_provider_override + or self.create_embedding_provider( + self.config.embedding, *args, **kwargs + ), + eval=eval_provider_override + or self.create_eval_provider( + self.config.eval, + prompt_provider=prompt_provider, + *args, + **kwargs, + ), + llm=llm_provider_override + or self.create_llm_provider( + self.config.completions, *args, **kwargs + ), + prompt=prompt_provider_override + or self.create_prompt_provider( + self.config.prompt, *args, **kwargs + ), + kg=kg_provider_override + or self.create_kg_provider(self.config.kg, *args, **kwargs), + ) + + +class R2RPipeFactory: + def __init__(self, config: R2RConfig, providers: R2RProviders): + self.config = config + self.providers = providers + + def create_pipes( + self, + parsing_pipe_override: Optional[AsyncPipe] = None, + embedding_pipe_override: Optional[AsyncPipe] = None, + kg_pipe_override: Optional[AsyncPipe] = None, + kg_storage_pipe_override: Optional[AsyncPipe] = None, + kg_agent_pipe_override: Optional[AsyncPipe] = None, + vector_storage_pipe_override: Optional[AsyncPipe] = None, + vector_search_pipe_override: Optional[AsyncPipe] = None, + rag_pipe_override: Optional[AsyncPipe] = None, + streaming_rag_pipe_override: Optional[AsyncPipe] = None, + eval_pipe_override: Optional[AsyncPipe] = None, + *args, + **kwargs, + ) -> R2RPipes: + return R2RPipes( + parsing_pipe=parsing_pipe_override + or self.create_parsing_pipe( + self.config.ingestion.get("excluded_parsers"), *args, **kwargs + ), + embedding_pipe=embedding_pipe_override + or 
self.create_embedding_pipe(*args, **kwargs), + kg_pipe=kg_pipe_override or self.create_kg_pipe(*args, **kwargs), + kg_storage_pipe=kg_storage_pipe_override + or self.create_kg_storage_pipe(*args, **kwargs), + kg_agent_search_pipe=kg_agent_pipe_override + or self.create_kg_agent_pipe(*args, **kwargs), + vector_storage_pipe=vector_storage_pipe_override + or self.create_vector_storage_pipe(*args, **kwargs), + vector_search_pipe=vector_search_pipe_override + or self.create_vector_search_pipe(*args, **kwargs), + rag_pipe=rag_pipe_override + or self.create_rag_pipe(*args, **kwargs), + streaming_rag_pipe=streaming_rag_pipe_override + or self.create_rag_pipe(stream=True, *args, **kwargs), + eval_pipe=eval_pipe_override + or self.create_eval_pipe(*args, **kwargs), + ) + + def create_parsing_pipe( + self, excluded_parsers: Optional[list] = None, *args, **kwargs + ) -> Any: + from r2r.pipes import ParsingPipe + + return ParsingPipe(excluded_parsers=excluded_parsers or []) + + def create_embedding_pipe(self, *args, **kwargs) -> Any: + if self.config.embedding.provider is None: + return None + + from r2r.base import RecursiveCharacterTextSplitter + from r2r.pipes import EmbeddingPipe + + text_splitter_config = self.config.embedding.extra_fields.get( + "text_splitter" + ) + if not text_splitter_config: + raise ValueError( + "Text splitter config not found in embedding config" + ) + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=text_splitter_config["chunk_size"], + chunk_overlap=text_splitter_config["chunk_overlap"], + length_function=len, + is_separator_regex=False, + ) + return EmbeddingPipe( + embedding_provider=self.providers.embedding, + vector_db_provider=self.providers.vector_db, + text_splitter=text_splitter, + embedding_batch_size=self.config.embedding.batch_size, + ) + + def create_vector_storage_pipe(self, *args, **kwargs) -> Any: + if self.config.embedding.provider is None: + return None + + from r2r.pipes import VectorStoragePipe + + return VectorStoragePipe(vector_db_provider=self.providers.vector_db) + + def create_vector_search_pipe(self, *args, **kwargs) -> Any: + if self.config.embedding.provider is None: + return None + + from r2r.pipes import VectorSearchPipe + + return VectorSearchPipe( + vector_db_provider=self.providers.vector_db, + embedding_provider=self.providers.embedding, + ) + + def create_kg_pipe(self, *args, **kwargs) -> Any: + if self.config.kg.provider is None: + return None + + from r2r.base import RecursiveCharacterTextSplitter + from r2r.pipes import KGExtractionPipe + + text_splitter_config = self.config.kg.extra_fields.get("text_splitter") + if not text_splitter_config: + raise ValueError("Text splitter config not found in kg config.") + + text_splitter = RecursiveCharacterTextSplitter( + chunk_size=text_splitter_config["chunk_size"], + chunk_overlap=text_splitter_config["chunk_overlap"], + length_function=len, + is_separator_regex=False, + ) + return KGExtractionPipe( + kg_provider=self.providers.kg, + llm_provider=self.providers.llm, + prompt_provider=self.providers.prompt, + vector_db_provider=self.providers.vector_db, + text_splitter=text_splitter, + kg_batch_size=self.config.kg.batch_size, + ) + + def create_kg_storage_pipe(self, *args, **kwargs) -> Any: + if self.config.kg.provider is None: + return None + + from r2r.pipes import KGStoragePipe + + return KGStoragePipe( + kg_provider=self.providers.kg, + embedding_provider=self.providers.embedding, + ) + + def create_kg_agent_pipe(self, *args, **kwargs) -> Any: + if self.config.kg.provider is 
None: + return None + + from r2r.pipes import KGAgentSearchPipe + + return KGAgentSearchPipe( + kg_provider=self.providers.kg, + llm_provider=self.providers.llm, + prompt_provider=self.providers.prompt, + ) + + def create_rag_pipe(self, stream: bool = False, *args, **kwargs) -> Any: + if stream: + from r2r.pipes import StreamingSearchRAGPipe + + return StreamingSearchRAGPipe( + llm_provider=self.providers.llm, + prompt_provider=self.providers.prompt, + ) + else: + from r2r.pipes import SearchRAGPipe + + return SearchRAGPipe( + llm_provider=self.providers.llm, + prompt_provider=self.providers.prompt, + ) + + def create_eval_pipe(self, *args, **kwargs) -> Any: + from r2r.pipes import EvalPipe + + return EvalPipe(eval_provider=self.providers.eval) + + +class R2RPipelineFactory: + def __init__(self, config: R2RConfig, pipes: R2RPipes): + self.config = config + self.pipes = pipes + + def create_ingestion_pipeline(self, *args, **kwargs) -> IngestionPipeline: + """Factory method to create an ingestion pipeline.""" + ingestion_pipeline = IngestionPipeline() + + ingestion_pipeline.add_pipe( + pipe=self.pipes.parsing_pipe, parsing_pipe=True + ) + # Add embedding pipes if provider is set + if self.config.embedding.provider is not None: + ingestion_pipeline.add_pipe( + self.pipes.embedding_pipe, embedding_pipe=True + ) + ingestion_pipeline.add_pipe( + self.pipes.vector_storage_pipe, embedding_pipe=True + ) + # Add KG pipes if provider is set + if self.config.kg.provider is not None: + ingestion_pipeline.add_pipe(self.pipes.kg_pipe, kg_pipe=True) + ingestion_pipeline.add_pipe( + self.pipes.kg_storage_pipe, kg_pipe=True + ) + + return ingestion_pipeline + + def create_search_pipeline(self, *args, **kwargs) -> SearchPipeline: + """Factory method to create a search pipeline.""" + search_pipeline = SearchPipeline() + + # Add vector search pipes if embedding provider and vector provider are set + if ( + self.config.embedding.provider is not None + and self.config.vector_database.provider is not None + ): + search_pipeline.add_pipe( + self.pipes.vector_search_pipe, vector_search_pipe=True + ) + + # Add KG pipes if provider is set + if self.config.kg.provider is not None: + search_pipeline.add_pipe( + self.pipes.kg_agent_search_pipe, kg_pipe=True + ) + + return search_pipeline + + def create_rag_pipeline( + self, + search_pipeline: SearchPipeline, + stream: bool = False, + *args, + **kwargs, + ) -> RAGPipeline: + rag_pipe = ( + self.pipes.streaming_rag_pipe if stream else self.pipes.rag_pipe + ) + + rag_pipeline = RAGPipeline() + rag_pipeline.set_search_pipeline(search_pipeline) + rag_pipeline.add_pipe(rag_pipe) + return rag_pipeline + + def create_eval_pipeline(self, *args, **kwargs) -> EvalPipeline: + eval_pipeline = EvalPipeline() + eval_pipeline.add_pipe(self.pipes.eval_pipe) + return eval_pipeline + + def create_pipelines( + self, + ingestion_pipeline: Optional[IngestionPipeline] = None, + search_pipeline: Optional[SearchPipeline] = None, + rag_pipeline: Optional[RAGPipeline] = None, + streaming_rag_pipeline: Optional[RAGPipeline] = None, + eval_pipeline: Optional[EvalPipeline] = None, + *args, + **kwargs, + ) -> R2RPipelines: + try: + self.configure_logging() + except Exception as e: + logger.warning(f"Error configuring logging: {e}") + search_pipeline = search_pipeline or self.create_search_pipeline( + *args, **kwargs + ) + return R2RPipelines( + ingestion_pipeline=ingestion_pipeline + or self.create_ingestion_pipeline(*args, **kwargs), + search_pipeline=search_pipeline, + rag_pipeline=rag_pipeline + 
or self.create_rag_pipeline( + search_pipeline=search_pipeline, + stream=False, + *args, + **kwargs, + ), + streaming_rag_pipeline=streaming_rag_pipeline + or self.create_rag_pipeline( + search_pipeline=search_pipeline, + stream=True, + *args, + **kwargs, + ), + eval_pipeline=eval_pipeline + or self.create_eval_pipeline(*args, **kwargs), + ) + + def configure_logging(self): + KVLoggingSingleton.configure(self.config.logging) diff --git a/R2R/r2r/main/assembly/factory_extensions.py b/R2R/r2r/main/assembly/factory_extensions.py new file mode 100755 index 00000000..56e82ef7 --- /dev/null +++ b/R2R/r2r/main/assembly/factory_extensions.py @@ -0,0 +1,69 @@ +from r2r.main import R2RPipeFactory +from r2r.pipes.retrieval.multi_search import MultiSearchPipe +from r2r.pipes.retrieval.query_transform_pipe import QueryTransformPipe + + +class R2RPipeFactoryWithMultiSearch(R2RPipeFactory): + QUERY_GENERATION_TEMPLATE: dict = ( + { # TODO - Can we have stricter typing like so? `: {"template": str, "input_types": dict[str, str]} = {`` + "template": "### Instruction:\n\nGiven the following query that follows to write a double newline separated list of up to {num_outputs} queries meant to help answer the original query. \nDO NOT generate any single query which is likely to require information from multiple distinct documents, \nEACH single query will be used to carry out a cosine similarity semantic search over distinct indexed documents, such as varied medical documents. \nFOR EXAMPLE if asked `how do the key themes of Great Gatsby compare with 1984`, the two queries would be \n`What are the key themes of Great Gatsby?` and `What are the key themes of 1984?`.\nHere is the original user query to be transformed into answers:\n\n### Query:\n{message}\n\n### Response:\n", + "input_types": {"num_outputs": "int", "message": "str"}, + } + ) + + def create_vector_search_pipe(self, *args, **kwargs): + """ + A factory method to create a search pipe. 
+ + Overrides include + task_prompt_name: str + multi_query_transform_pipe_override: QueryTransformPipe + multi_inner_search_pipe_override: SearchPipe + query_generation_template_override: {'template': str, 'input_types': dict[str, str]} + """ + multi_search_config = MultiSearchPipe.PipeConfig() + if kwargs.get("task_prompt_name") and kwargs.get( + "query_generation_template_override" + ): + raise ValueError( + "Cannot provide both `task_prompt_name` and `query_generation_template_override`" + ) + task_prompt_name = ( + kwargs.get("task_prompt_name") + or f"{multi_search_config.name}_task_prompt" + ) + if kwargs.get("query_generation_template_override"): + # Add a prompt for transforming the user query + template = kwargs.get("query_generation_template_override") + self.providers.prompt.add_prompt( + **( + kwargs.get("query_generation_template_override") + or self.QUERY_GENERATION_TEMPLATE + ), + ) + task_prompt_name = template["name"] + + # Initialize the new query transform pipe + query_transform_pipe = kwargs.get( + "multi_query_transform_pipe_override", None + ) or QueryTransformPipe( + llm_provider=self.providers.llm, + prompt_provider=self.providers.prompt, + config=QueryTransformPipe.QueryTransformConfig( + name=multi_search_config.name, + task_prompt=task_prompt_name, + ), + ) + # Create search pipe override and pipes + inner_search_pipe = kwargs.get( + "multi_inner_search_pipe_override", None + ) or super().create_vector_search_pipe(*args, **kwargs) + + # TODO - modify `create_..._pipe` to allow naming the pipe + inner_search_pipe.config.name = multi_search_config.name + + return MultiSearchPipe( + query_transform_pipe=query_transform_pipe, + inner_search_pipe=inner_search_pipe, + config=multi_search_config, + ) diff --git a/R2R/r2r/main/engine.py b/R2R/r2r/main/engine.py new file mode 100755 index 00000000..a73b932e --- /dev/null +++ b/R2R/r2r/main/engine.py @@ -0,0 +1,109 @@ +from typing import Optional + +from r2r.base import KVLoggingSingleton, RunManager +from r2r.base.abstractions.base import AsyncSyncMeta, syncable + +from .abstractions import R2RPipelines, R2RProviders +from .assembly.config import R2RConfig +from .services.ingestion_service import IngestionService +from .services.management_service import ManagementService +from .services.retrieval_service import RetrievalService + + +class R2REngine(metaclass=AsyncSyncMeta): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + pipelines: R2RPipelines, + run_manager: Optional[RunManager] = None, + ): + logging_connection = KVLoggingSingleton() + run_manager = run_manager or RunManager(logging_connection) + + self.config = config + self.providers = providers + self.pipelines = pipelines + self.logging_connection = KVLoggingSingleton() + self.run_manager = run_manager + + self.ingestion_service = IngestionService( + config, providers, pipelines, run_manager, logging_connection + ) + self.retrieval_service = RetrievalService( + config, providers, pipelines, run_manager, logging_connection + ) + self.management_service = ManagementService( + config, providers, pipelines, run_manager, logging_connection + ) + + # Ingestion routes + @syncable + async def aingest_documents(self, *args, **kwargs): + return await self.ingestion_service.ingest_documents(*args, **kwargs) + + @syncable + async def aupdate_documents(self, *args, **kwargs): + return await self.ingestion_service.update_documents(*args, **kwargs) + + @syncable + async def aingest_files(self, *args, **kwargs): + return await 
self.ingestion_service.ingest_files(*args, **kwargs) + + @syncable + async def aupdate_files(self, *args, **kwargs): + return await self.ingestion_service.update_files(*args, **kwargs) + + # Retrieval routes + @syncable + async def asearch(self, *args, **kwargs): + return await self.retrieval_service.search(*args, **kwargs) + + @syncable + async def arag(self, *args, **kwargs): + return await self.retrieval_service.rag(*args, **kwargs) + + @syncable + async def aevaluate(self, *args, **kwargs): + return await self.retrieval_service.evaluate(*args, **kwargs) + + # Management routes + @syncable + async def aupdate_prompt(self, *args, **kwargs): + return await self.management_service.update_prompt(*args, **kwargs) + + @syncable + async def alogs(self, *args, **kwargs): + return await self.management_service.alogs(*args, **kwargs) + + @syncable + async def aanalytics(self, *args, **kwargs): + return await self.management_service.aanalytics(*args, **kwargs) + + @syncable + async def aapp_settings(self, *args, **kwargs): + return await self.management_service.aapp_settings(*args, **kwargs) + + @syncable + async def ausers_overview(self, *args, **kwargs): + return await self.management_service.ausers_overview(*args, **kwargs) + + @syncable + async def adelete(self, *args, **kwargs): + return await self.management_service.delete(*args, **kwargs) + + @syncable + async def adocuments_overview(self, *args, **kwargs): + return await self.management_service.adocuments_overview( + *args, **kwargs + ) + + @syncable + async def inspect_knowledge_graph(self, *args, **kwargs): + return await self.management_service.inspect_knowledge_graph( + *args, **kwargs + ) + + @syncable + async def adocument_chunks(self, *args, **kwargs): + return await self.management_service.document_chunks(*args, **kwargs) diff --git a/R2R/r2r/main/execution.py b/R2R/r2r/main/execution.py new file mode 100755 index 00000000..187a2eea --- /dev/null +++ b/R2R/r2r/main/execution.py @@ -0,0 +1,421 @@ +import ast +import asyncio +import json +import os +import uuid +from typing import Optional, Union + +from fastapi import UploadFile + +from r2r.base import ( + AnalysisTypes, + FilterCriteria, + GenerationConfig, + KGSearchSettings, + VectorSearchSettings, + generate_id_from_label, +) + +from .api.client import R2RClient +from .assembly.builder import R2RBuilder +from .assembly.config import R2RConfig +from .r2r import R2R + + +class R2RExecutionWrapper: + """A wrapper that runs R2R commands against either a remote R2R server (client mode) or an in-process R2R application.""" + + def __init__( + self, + config_path: Optional[str] = None, + config_name: Optional[str] = None, + client_mode: bool = True, + base_url="http://localhost:8000", + ): + if config_path and config_name: + raise ValueError("Cannot specify both config_path and config_name") + + # Handle fire CLI + if isinstance(client_mode, str): + client_mode = client_mode.lower() == "true" + self.client_mode = client_mode + self.base_url = base_url + + if self.client_mode: + self.client = R2RClient(base_url) + self.app = None + else: + config = ( + R2RConfig.from_json(config_path) + if config_path + else R2RConfig.from_json( + R2RBuilder.CONFIG_OPTIONS[config_name or "default"] + ) + ) + + self.client = None + self.app = R2R(config=config) + + def serve(self, host: str = "0.0.0.0", port: int = 8000): + if not self.client_mode: + self.app.serve(host, port) + else: + raise ValueError( + "Serve method is only available when `client_mode=False`."
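                # A minimal usage sketch (illustrative; `fire` exposes each public
                # method of this class as a CLI subcommand, so the module-path
                # invocation below is an assumption based on the repo layout):
                #
                #   wrapper = R2RExecutionWrapper(client_mode=False)
                #   wrapper.serve(host="0.0.0.0", port=8000)
                #
                # or from the shell:
                #
                #   python -m r2r.main.execution serve --client_mode=False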
+ ) + + @staticmethod + def _parse_metadata_string(metadata_string: str) -> list[dict]: + """ + Convert a string representation of metadata into a list of dictionaries. + + The input string can be in one of two formats: + 1. JSON array of objects: '[{"key": "value"}, {"key2": "value2"}]' + 2. Python-like list of dictionaries: "[{'key': 'value'}, {'key2': 'value2'}]" + + Args: + metadata_string (str): The string representation of metadata. + + Returns: + list[dict]: A list of dictionaries representing the metadata. + + Raises: + ValueError: If the string cannot be parsed into a list of dictionaries. + """ + if not metadata_string: + return [] + + try: + # First, try to parse as JSON + return json.loads(metadata_string) + except json.JSONDecodeError as e: + try: + # If JSON parsing fails, try to evaluate as a Python literal + result = ast.literal_eval(metadata_string) + if not isinstance(result, list) or not all( + isinstance(item, dict) for item in result + ): + raise ValueError( + "The string does not represent a list of dictionaries" + ) from e + return result + except (ValueError, SyntaxError) as exc: + raise ValueError( + "Unable to parse the metadata string. " + "Please ensure it's a valid JSON array or Python list of dictionaries." + ) from exc + + def ingest_files( + self, + file_paths: list[str], + metadatas: Optional[list[dict]] = None, + document_ids: Optional[list[Union[uuid.UUID, str]]] = None, + versions: Optional[list[str]] = None, + ): + if isinstance(file_paths, str): + file_paths = list(file_paths.split(",")) + if isinstance(metadatas, str): + metadatas = self._parse_metadata_string(metadatas) + if isinstance(document_ids, str): + document_ids = list(document_ids.split(",")) + if isinstance(versions, str): + versions = list(versions.split(",")) + + all_file_paths = [] + for path in file_paths: + if os.path.isdir(path): + for root, _, files in os.walk(path): + all_file_paths.extend( + os.path.join(root, file) for file in files + ) + else: + all_file_paths.append(path) + + if not document_ids: + document_ids = [ + generate_id_from_label(os.path.basename(file_path)) + for file_path in all_file_paths + ] + + files = [ + UploadFile( + filename=os.path.basename(file_path), + file=open(file_path, "rb"), + ) + for file_path in all_file_paths + ] + + for file in files: + file.file.seek(0, 2) + file.size = file.file.tell() + file.file.seek(0) + + try: + if self.client_mode: + return self.client.ingest_files( + file_paths=all_file_paths, + document_ids=document_ids, + metadatas=metadatas, + versions=versions, + monitor=True, + )["results"] + else: + return self.app.ingest_files( + files=files, + document_ids=document_ids, + metadatas=metadatas, + versions=versions, + ) + finally: + for file in files: + file.file.close() + + def update_files( + self, + file_paths: list[str], + document_ids: list[str], + metadatas: Optional[list[dict]] = None, + ): + if isinstance(file_paths, str): + file_paths = list(file_paths.split(",")) + if isinstance(metadatas, str): + metadatas = self._parse_metadata_string(metadatas) + if isinstance(document_ids, str): + document_ids = list(document_ids.split(",")) + + if self.client_mode: + return self.client.update_files( + file_paths=file_paths, + document_ids=document_ids, + metadatas=metadatas, + monitor=True, + )["results"] + else: + files = [ + UploadFile( + filename=file_path, + file=open(file_path, "rb"), + ) + for file_path in file_paths + ] + return self.app.update_files( + files=files, document_ids=document_ids, metadatas=metadatas + ) + + def search( + self,
query: str, + use_vector_search: bool = True, + search_filters: Optional[dict] = None, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg_search: bool = False, + kg_agent_generation_config: Optional[dict] = None, + ): + if self.client_mode: + return self.client.search( + query, + use_vector_search, + search_filters, + search_limit, + do_hybrid_search, + use_kg_search, + kg_agent_generation_config, + )["results"] + else: + return self.app.search( + query, + VectorSearchSettings( + use_vector_search=use_vector_search, + search_filters=search_filters or {}, + search_limit=search_limit, + do_hybrid_search=do_hybrid_search, + ), + KGSearchSettings( + use_kg_search=use_kg_search, + agent_generation_config=GenerationConfig( + **(kg_agent_generation_config or {}) + ), + ), + ) + + def rag( + self, + query: str, + use_vector_search: bool = True, + search_filters: Optional[dict] = None, + search_limit: int = 10, + do_hybrid_search: bool = False, + use_kg_search: bool = False, + kg_agent_generation_config: Optional[dict] = None, + stream: bool = False, + rag_generation_config: Optional[dict] = None, + ): + if self.client_mode: + response = self.client.rag( + query=query, + use_vector_search=use_vector_search, + search_filters=search_filters or {}, + search_limit=search_limit, + do_hybrid_search=do_hybrid_search, + use_kg_search=use_kg_search, + kg_agent_generation_config=kg_agent_generation_config, + rag_generation_config=rag_generation_config, + ) + if not stream: + return response["results"] + return response + else: + response = self.app.rag( + query, + vector_search_settings=VectorSearchSettings( + use_vector_search=use_vector_search, + search_filters=search_filters or {}, + search_limit=search_limit, + do_hybrid_search=do_hybrid_search, + ), + kg_search_settings=KGSearchSettings( + use_kg_search=use_kg_search, + agent_generation_config=GenerationConfig( + **(kg_agent_generation_config or {}) + ), + ), + rag_generation_config=GenerationConfig( + **(rag_generation_config or {}) + ), + ) + if not stream: + return response + else: + + async def async_generator(): + async for chunk in response: + yield chunk + + def sync_generator(): + # Drive the async generator from synchronous code; errors + # propagate instead of silently truncating the stream. + loop = asyncio.get_event_loop() + async_gen = async_generator() + while True: + try: + yield loop.run_until_complete( + async_gen.__anext__() + ) + except StopAsyncIteration: + break + + return sync_generator() + + def documents_overview( + self, + document_ids: Optional[list[str]] = None, + user_ids: Optional[list[str]] = None, + ): + if self.client_mode: + return self.client.documents_overview(document_ids, user_ids)[ + "results" + ] + else: + return self.app.documents_overview(document_ids, user_ids) + + def delete( + self, + keys: list[str], + values: list[str], + ): + if self.client_mode: + return self.client.delete(keys, values)["results"] + else: + return self.app.delete(keys, values) + + def logs(self, log_type_filter: Optional[str] = None): + if self.client_mode: + return self.client.logs(log_type_filter)["results"] + else: + return self.app.logs(log_type_filter) + + def document_chunks(self, document_id: str): + doc_uuid = uuid.UUID(document_id) + if self.client_mode: + return self.client.document_chunks(doc_uuid)["results"] + else: + return self.app.document_chunks(doc_uuid) + + def app_settings(self): + if self.client_mode: + return self.client.app_settings() + else: + return self.app.app_settings() + + def users_overview(self, user_ids: Optional[list[uuid.UUID]] = None): + if
self.client_mode: + return self.client.users_overview(user_ids)["results"] + else: + return self.app.users_overview(user_ids) + + def analytics( + self, + filters: Optional[str] = None, + analysis_types: Optional[str] = None, + ): + filter_criteria = FilterCriteria(filters=filters) + analysis_types = AnalysisTypes(analysis_types=analysis_types) + + if self.client_mode: + return self.client.analytics( + filter_criteria=filter_criteria.model_dump(), + analysis_types=analysis_types.model_dump(), + )["results"] + else: + return self.app.analytics( + filter_criteria=filter_criteria, analysis_types=analysis_types + ) + + def ingest_sample_file(self, no_media: bool = True, option: int = 0): + """Ingest the first sample file into R2R.""" + from r2r.examples.scripts.sample_data_ingestor import ( + SampleDataIngestor, + ) + + sample_ingestor = SampleDataIngestor(self) + return sample_ingestor.ingest_sample_file( + no_media=no_media, option=option + ) + + def ingest_sample_files(self, no_media: bool = True): + """Ingest all sample files into R2R.""" + from r2r.examples.scripts.sample_data_ingestor import ( + SampleDataIngestor, + ) + + sample_ingestor = SampleDataIngestor(self) + return sample_ingestor.ingest_sample_files(no_media=no_media) + + def inspect_knowledge_graph(self, limit: int = 100) -> str: + if self.client_mode: + return self.client.inspect_knowledge_graph(limit)["results"] + else: + return self.app.inspect_knowledge_graph(limit) + + def health(self) -> str: + if self.client_mode: + return self.client.health() + else: + pass + + def get_app(self): + if not self.client_mode: + return self.app.app.app + else: + raise Exception( + "`get_app` method is only available when running with `client_mode=False`." + ) + + +if __name__ == "__main__": + import fire + + fire.Fire(R2RExecutionWrapper) diff --git a/R2R/r2r/main/r2r.py b/R2R/r2r/main/r2r.py new file mode 100755 index 00000000..2d8601b2 --- /dev/null +++ b/R2R/r2r/main/r2r.py @@ -0,0 +1,51 @@ +from typing import Optional + +from .app import R2RApp +from .assembly.config import R2RConfig +from .engine import R2REngine + + +class R2R: + engine: R2REngine + app: R2RApp + + def __init__( + self, + engine: Optional[R2REngine] = None, + app: Optional[R2RApp] = None, + config: Optional[R2RConfig] = None, + from_config: Optional[str] = None, + *args, + **kwargs + ): + if engine and app: + self.engine = engine + self.app = app + elif (config or from_config) or ( + config is None and from_config is None + ): + from .assembly.builder import R2RBuilder + + # Handle the case where 'from_config' is None and 'config' is None + if not config and not from_config: + from_config = "default" + builder = R2RBuilder( + config=config, + from_config=from_config, + ) + built = builder.build() + self.engine = built.engine + self.app = built.app + else: + raise ValueError( + "Must provide either 'engine' and 'app', or 'config'/'from_config' to build the R2R object."
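            # Construction sketch covering the supported paths above (the file
            # name is hypothetical):
            #
            #   R2R()                                        # builds from the "default" config
            #   R2R(from_config="default")                   # same, but explicit
            #   R2R(config=R2RConfig.from_json("my.json"))   # custom config file
            #   R2R(engine=my_engine, app=my_app)            # bring your own components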
+ ) + + def __getattr__(self, name): + # Check if the attribute name is 'app' and return it directly + if name == "app": + return self.app + elif name == "serve": + return self.app.serve + # Otherwise, delegate to the engine + return getattr(self.engine, name) diff --git a/R2R/r2r/main/services/__init__.py b/R2R/r2r/main/services/__init__.py new file mode 100755 index 00000000..e69de29b --- /dev/null +++ b/R2R/r2r/main/services/__init__.py diff --git a/R2R/r2r/main/services/base.py b/R2R/r2r/main/services/base.py new file mode 100755 index 00000000..02c0675d --- /dev/null +++ b/R2R/r2r/main/services/base.py @@ -0,0 +1,22 @@ +from abc import ABC + +from r2r.base import KVLoggingSingleton, RunManager + +from ..abstractions import R2RPipelines, R2RProviders +from ..assembly.config import R2RConfig + + +class Service(ABC): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + pipelines: R2RPipelines, + run_manager: RunManager, + logging_connection: KVLoggingSingleton, + ): + self.config = config + self.providers = providers + self.pipelines = pipelines + self.run_manager = run_manager + self.logging_connection = logging_connection diff --git a/R2R/r2r/main/services/ingestion_service.py b/R2R/r2r/main/services/ingestion_service.py new file mode 100755 index 00000000..5677807a --- /dev/null +++ b/R2R/r2r/main/services/ingestion_service.py @@ -0,0 +1,505 @@ +import json +import logging +import uuid +from collections import defaultdict +from datetime import datetime +from typing import Any, Optional, Union + +from fastapi import Form, UploadFile + +from r2r.base import ( + Document, + DocumentInfo, + DocumentType, + KVLoggingSingleton, + R2RDocumentProcessingError, + R2RException, + RunManager, + generate_id_from_label, + increment_version, + to_async_generator, +) +from r2r.telemetry.telemetry_decorator import telemetry_event + +from ..abstractions import R2RPipelines, R2RProviders +from ..api.requests import R2RIngestFilesRequest, R2RUpdateFilesRequest +from ..assembly.config import R2RConfig +from .base import Service + +logger = logging.getLogger(__name__) +MB_CONVERSION_FACTOR = 1024 * 1024 + + +class IngestionService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + pipelines: R2RPipelines, + run_manager: RunManager, + logging_connection: KVLoggingSingleton, + ): + super().__init__( + config, providers, pipelines, run_manager, logging_connection + ) + + def _file_to_document( + self, file: UploadFile, document_id: uuid.UUID, metadata: dict + ) -> Document: + file_extension = file.filename.split(".")[-1].lower() + if file_extension.upper() not in DocumentType.__members__: + raise R2RException( + status_code=415, + message=f"'{file_extension}' is not a valid DocumentType.", + ) + + document_title = ( + metadata.get("title", None) or file.filename.split("/")[-1] + ) + metadata["title"] = document_title + + return Document( + id=document_id, + type=DocumentType[file_extension.upper()], + data=file.file.read(), + metadata=metadata, + ) + + @telemetry_event("IngestDocuments") + async def ingest_documents( + self, + documents: list[Document], + versions: Optional[list[str]] = None, + *args: Any, + **kwargs: Any, + ): + if len(documents) == 0: + raise R2RException( + status_code=400, message="No documents provided for ingestion." 
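            # Sketch of a minimal ingestion payload (hypothetical values), built
            # the same way `_file_to_document` above builds it; assumes PDF is a
            # member of DocumentType:
            #
            #   Document(
            #       id=generate_id_from_label("report.pdf"),
            #       type=DocumentType["PDF"],
            #       data=open("report.pdf", "rb").read(),
            #       metadata={"title": "report.pdf"},
            #   )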
+ ) + + document_infos = [] + skipped_documents = [] + processed_documents = {} + duplicate_documents = defaultdict(list) + + existing_document_info = { + doc_info.document_id: doc_info + for doc_info in self.providers.vector_db.get_documents_overview() + } + + for iteration, document in enumerate(documents): + version = versions[iteration] if versions else "v0" + + # Check for duplicates within the current batch + if document.id in processed_documents: + duplicate_documents[document.id].append( + document.metadata.get("title", str(document.id)) + ) + continue + + if ( + document.id in existing_document_info + and existing_document_info[document.id].version == version + and existing_document_info[document.id].status == "success" + ): + logger.error( + f"Document with ID {document.id} was already successfully processed." + ) + if len(documents) == 1: + raise R2RException( + status_code=409, + message=f"Document with ID {document.id} was already successfully processed.", + ) + skipped_documents.append( + ( + document.id, + document.metadata.get("title", None) + or str(document.id), + ) + ) + continue + + now = datetime.now() + document_infos.append( + DocumentInfo( + document_id=document.id, + version=version, + size_in_bytes=len(document.data), + metadata=document.metadata.copy(), + title=document.metadata.get("title", str(document.id)), + user_id=document.metadata.get("user_id", None), + created_at=now, + updated_at=now, + status="processing", # Set initial status to `processing` + ) + ) + + processed_documents[document.id] = document.metadata.get( + "title", str(document.id) + ) + + if duplicate_documents: + duplicate_details = [ + f"{doc_id}: {', '.join(titles)}" + for doc_id, titles in duplicate_documents.items() + ] + warning_message = f"Duplicate documents detected: {'; '.join(duplicate_details)}. These duplicates were skipped." + raise R2RException(status_code=418, message=warning_message) + + if skipped_documents and len(skipped_documents) == len(documents): + logger.error("All provided documents already exist.") + raise R2RException( + status_code=409, + message="All provided documents already exist. Use the `update_documents` endpoint instead to update these documents.", + ) + + # Insert pending document infos + self.providers.vector_db.upsert_documents_overview(document_infos) + ingestion_results = await self.pipelines.ingestion_pipeline.run( + input=to_async_generator( + [ + doc + for doc in documents + if doc.id + not in [skipped[0] for skipped in skipped_documents] + ] + ), + versions=[info.version for info in document_infos], + run_manager=self.run_manager, + *args, + **kwargs, + ) + + return await self._process_ingestion_results( + ingestion_results, + document_infos, + skipped_documents, + processed_documents, + ) + + @telemetry_event("IngestFiles") + async def ingest_files( + self, + files: list[UploadFile], + metadatas: Optional[list[dict]] = None, + document_ids: Optional[list[uuid.UUID]] = None, + versions: Optional[list[str]] = None, + *args: Any, + **kwargs: Any, + ): + if not files: + raise R2RException( + status_code=400, message="No files provided for ingestion." 
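            # Callers constructing UploadFile objects by hand must populate
            # `size`, or the size comparison below fails; the execution wrapper
            # does this as (file name hypothetical):
            #
            #   f = UploadFile(filename="report.pdf", file=open("report.pdf", "rb"))
            #   f.file.seek(0, 2)   # seek to the end of the file
            #   f.size = f.file.tell()
            #   f.file.seek(0)      # rewind so ingestion can read the content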
+ ) + + try: + documents = [] + for iteration, file in enumerate(files): + logger.info(f"Processing file: {file.filename}") + if ( + file.size + > self.config.app.get("max_file_size_in_mb", 32) + * MB_CONVERSION_FACTOR + ): + raise R2RException( + status_code=413, + message=f"File size exceeds maximum allowed size: {file.filename}", + ) + if not file.filename: + raise R2RException( + status_code=400, message="File name not provided." + ) + + document_metadata = metadatas[iteration] if metadatas else {} + document_id = ( + document_ids[iteration] + if document_ids + else generate_id_from_label(file.filename.split("/")[-1]) + ) + + document = self._file_to_document( + file, document_id, document_metadata + ) + documents.append(document) + + return await self.ingest_documents( + documents, versions, *args, **kwargs + ) + + finally: + for file in files: + file.file.close() + + @telemetry_event("UpdateFiles") + async def update_files( + self, + files: list[UploadFile], + document_ids: list[uuid.UUID], + metadatas: Optional[list[dict]] = None, + *args: Any, + **kwargs: Any, + ): + if not files: + raise R2RException( + status_code=400, message="No files provided for update." + ) + + try: + if len(document_ids) != len(files): + raise R2RException( + status_code=400, + message="Number of ids does not match number of files.", + ) + + documents_overview = await self._documents_overview( + document_ids=document_ids + ) + if len(documents_overview) != len(files): + raise R2RException( + status_code=404, + message="One or more documents was not found.", + ) + + documents = [] + new_versions = [] + + for it, (file, doc_id, doc_info) in enumerate( + zip(files, document_ids, documents_overview) + ): + if not doc_info: + raise R2RException( + status_code=404, + message=f"Document with id {doc_id} not found.", + ) + + new_version = increment_version(doc_info.version) + new_versions.append(new_version) + + updated_metadata = ( + metadatas[it] if metadatas else doc_info.metadata + ) + updated_metadata["title"] = ( + updated_metadata.get("title", None) + or file.filename.split("/")[-1] + ) + + document = self._file_to_document( + file, doc_id, updated_metadata + ) + documents.append(document) + + ingestion_results = await self.ingest_documents( + documents, versions=new_versions, *args, **kwargs + ) + + for doc_id, old_version in zip( + document_ids, + [doc_info.version for doc_info in documents_overview], + ): + await self._delete( + ["document_id", "version"], [str(doc_id), old_version] + ) + self.providers.vector_db.delete_from_documents_overview( + doc_id, old_version + ) + + return ingestion_results + + finally: + for file in files: + file.file.close() + + async def _process_ingestion_results( + self, + ingestion_results: dict, + document_infos: list[DocumentInfo], + skipped_documents: list[tuple[str, str]], + processed_documents: dict, + ): + skipped_ids = [ele[0] for ele in skipped_documents] + failed_ids = [] + successful_ids = [] + + results = {} + if ingestion_results["embedding_pipeline_output"]: + results = { + k: v for k, v in ingestion_results["embedding_pipeline_output"] + } + for doc_id, error in results.items(): + if isinstance(error, R2RDocumentProcessingError): + logger.error( + f"Error processing document with ID {error.document_id}: {error.message}" + ) + failed_ids.append(error.document_id) + elif isinstance(error, Exception): + logger.error(f"Error processing document: {error}") + failed_ids.append(doc_id) + else: + successful_ids.append(doc_id) + + documents_to_upsert = [] + for 
document_info in document_infos: + if document_info.document_id not in skipped_ids: + if document_info.document_id in failed_ids: + document_info.status = "failure" + elif document_info.document_id in successful_ids: + document_info.status = "success" + documents_to_upsert.append(document_info) + + if documents_to_upsert: + self.providers.vector_db.upsert_documents_overview( + documents_to_upsert + ) + + results = { + "processed_documents": [ + f"Document '{processed_documents[document_id]}' processed successfully." + for document_id in successful_ids + ], + "failed_documents": [ + f"Document '{processed_documents[document_id]}': {results[document_id]}" + for document_id in failed_ids + ], + "skipped_documents": [ + f"Document '{filename}' skipped since it already exists." + for _, filename in skipped_documents + ], + } + + # TODO - Clean up logging for document parse results + run_ids = list(self.run_manager.run_info.keys()) + if run_ids: + run_id = run_ids[0] + for key in results: + if key in ["processed_documents", "failed_documents"]: + for value in results[key]: + await self.logging_connection.log( + log_id=run_id, + key="document_parse_result", + value=value, + ) + return results + + @staticmethod + def parse_ingest_files_form_data( + metadatas: Optional[str] = Form(None), + document_ids: str = Form(None), + versions: Optional[str] = Form(None), + ) -> R2RIngestFilesRequest: + try: + parsed_metadatas = ( + json.loads(metadatas) + if metadatas and metadatas != "null" + else None + ) + if parsed_metadatas is not None and not isinstance( + parsed_metadatas, list + ): + raise ValueError("metadatas must be a list of dictionaries") + + parsed_document_ids = ( + json.loads(document_ids) + if document_ids and document_ids != "null" + else None + ) + if parsed_document_ids is not None: + parsed_document_ids = [ + uuid.UUID(doc_id) for doc_id in parsed_document_ids + ] + + parsed_versions = ( + json.loads(versions) + if versions and versions != "null" + else None + ) + + request_data = { + "metadatas": parsed_metadatas, + "document_ids": parsed_document_ids, + "versions": parsed_versions, + } + return R2RIngestFilesRequest(**request_data) + except json.JSONDecodeError as e: + raise R2RException( + status_code=400, message=f"Invalid JSON in form data: {e}" + ) + except ValueError as e: + raise R2RException(status_code=400, message=str(e)) + except Exception as e: + raise R2RException( + status_code=400, message=f"Error processing form data: {e}" + ) + + @staticmethod + def parse_update_files_form_data( + metadatas: Optional[str] = Form(None), + document_ids: str = Form(...), + ) -> R2RUpdateFilesRequest: + try: + parsed_metadatas = ( + json.loads(metadatas) + if metadatas and metadatas != "null" + else None + ) + if parsed_metadatas is not None and not isinstance( + parsed_metadatas, list + ): + raise ValueError("metadatas must be a list of dictionaries") + + if not document_ids or document_ids == "null": + raise ValueError("document_ids is required and cannot be null") + + parsed_document_ids = json.loads(document_ids) + if not isinstance(parsed_document_ids, list): + raise ValueError("document_ids must be a list") + parsed_document_ids = [ + uuid.UUID(doc_id) for doc_id in parsed_document_ids + ] + + request_data = { + "metadatas": parsed_metadatas, + "document_ids": parsed_document_ids, + } + return R2RUpdateFilesRequest(**request_data) + except json.JSONDecodeError as e: + raise R2RException( + status_code=400, message=f"Invalid JSON in form data: {e}" + ) + except ValueError as e: + raise 
R2RException(status_code=400, message=str(e)) + except Exception as e: + raise R2RException( + status_code=400, message=f"Error processing form data: {e}" + ) + + # TODO - Move to mgmt service for document info, delete, post orchestration buildout + async def _documents_overview( + self, + document_ids: Optional[list[uuid.UUID]] = None, + user_ids: Optional[list[uuid.UUID]] = None, + *args: Any, + **kwargs: Any, + ): + return self.providers.vector_db.get_documents_overview( + filter_document_ids=( + [str(ele) for ele in document_ids] if document_ids else None + ), + filter_user_ids=( + [str(ele) for ele in user_ids] if user_ids else None + ), + ) + + async def _delete( + self, keys: list[str], values: list[Union[bool, int, str]] + ): + logger.info( + f"Deleting documents which match on these keys and values: ({keys}, {values})" + ) + + ids = self.providers.vector_db.delete_by_metadata(keys, values) + if not ids: + raise R2RException( + status_code=404, message="No entries found for deletion." + ) + return "Entries deleted successfully." diff --git a/R2R/r2r/main/services/management_service.py b/R2R/r2r/main/services/management_service.py new file mode 100755 index 00000000..00f1f56e --- /dev/null +++ b/R2R/r2r/main/services/management_service.py @@ -0,0 +1,385 @@ +import logging +import uuid +from collections import defaultdict +from typing import Any, Dict, List, Optional, Tuple, Union + +from r2r.base import ( + AnalysisTypes, + FilterCriteria, + KVLoggingSingleton, + LogProcessor, + R2RException, + RunManager, +) +from r2r.telemetry.telemetry_decorator import telemetry_event + +from ..abstractions import R2RPipelines, R2RProviders +from ..assembly.config import R2RConfig +from .base import Service + +logger = logging.getLogger(__name__) + + +class ManagementService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + pipelines: R2RPipelines, + run_manager: RunManager, + logging_connection: KVLoggingSingleton, + ): + super().__init__( + config, providers, pipelines, run_manager, logging_connection + ) + + @telemetry_event("UpdatePrompt") + async def update_prompt( + self, + name: str, + template: Optional[str] = None, + input_types: Optional[dict[str, str]] = {}, + *args, + **kwargs, + ): + self.providers.prompt.update_prompt(name, template, input_types) + return f"Prompt '{name}' updated successfully." + + @telemetry_event("Logs") + async def alogs( + self, + log_type_filter: Optional[str] = None, + max_runs_requested: int = 100, + *args: Any, + **kwargs: Any, + ): + if self.logging_connection is None: + raise R2RException( + status_code=404, message="Logging provider not found." + ) + if max_runs_requested > self.config.app.get( + "max_logs_per_request", 100 + ): + raise R2RException( + status_code=400, + message="Max runs requested exceeds the limit.", + ) + + run_info = await self.logging_connection.get_run_info( + limit=max_runs_requested, + log_type_filter=log_type_filter, + ) + run_ids = [run.run_id for run in run_info] + if len(run_ids) == 0: + return [] + logs = await self.logging_connection.get_logs(run_ids) + # Aggregate logs by run_id and include run_type + aggregated_logs = [] + + for run in run_info: + run_logs = [log for log in logs if log["log_id"] == run.run_id] + entries = [ + {"key": log["key"], "value": log["value"]} for log in run_logs + ][ + ::-1 + ] # Reverse order so that earliest logged values appear first.
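            # Illustrative shape of one aggregated record (values invented):
            #
            #   {
            #       "run_id": UUID("..."),
            #       "run_type": "rag",
            #       "entries": [{"key": "rag_generation_latency", "value": "0.42"}],
            #   }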
+ aggregated_logs.append( + { + "run_id": run.run_id, + "run_type": run.log_type, + "entries": entries, + } + ) + + return aggregated_logs + + @telemetry_event("Analytics") + async def aanalytics( + self, + filter_criteria: FilterCriteria, + analysis_types: AnalysisTypes, + *args, + **kwargs, + ): + run_info = await self.logging_connection.get_run_info(limit=100) + run_ids = [info.run_id for info in run_info] + + if not run_ids: + return { + "analytics_data": "No logs found.", + "filtered_logs": {}, + } + logs = await self.logging_connection.get_logs(run_ids=run_ids) + + filters = {} + if filter_criteria.filters: + for key, value in filter_criteria.filters.items(): + filters[key] = lambda log, value=value: ( + any( + entry.get("key") == value + for entry in log.get("entries", []) + ) + if "entries" in log + else log.get("key") == value + ) + + log_processor = LogProcessor(filters) + for log in logs: + if "entries" in log and isinstance(log["entries"], list): + log_processor.process_log(log) + elif "key" in log: + log_processor.process_log(log) + else: + logger.warning( + f"Skipping log due to missing or malformed 'entries': {log}" + ) + + filtered_logs = dict(log_processor.populations.items()) + results = {"filtered_logs": filtered_logs} + + if analysis_types and analysis_types.analysis_types: + for ( + filter_key, + analysis_config, + ) in analysis_types.analysis_types.items(): + if filter_key in filtered_logs: + analysis_type = analysis_config[0] + if analysis_type == "bar_chart": + extract_key = analysis_config[1] + results[filter_key] = ( + AnalysisTypes.generate_bar_chart_data( + filtered_logs[filter_key], extract_key + ) + ) + elif analysis_type == "basic_statistics": + extract_key = analysis_config[1] + results[filter_key] = ( + AnalysisTypes.calculate_basic_statistics( + filtered_logs[filter_key], extract_key + ) + ) + elif analysis_type == "percentile": + extract_key = analysis_config[1] + percentile = int(analysis_config[2]) + results[filter_key] = ( + AnalysisTypes.calculate_percentile( + filtered_logs[filter_key], + extract_key, + percentile, + ) + ) + else: + logger.warning( + f"Unknown analysis type for filter key '{filter_key}': {analysis_type}" + ) + + return results + + @telemetry_event("AppSettings") + async def aapp_settings(self, *args: Any, **kwargs: Any): + prompts = self.providers.prompt.get_all_prompts() + return { + "config": self.config.to_json(), + "prompts": { + name: prompt.dict() for name, prompt in prompts.items() + }, + } + + @telemetry_event("UsersOverview") + async def ausers_overview( + self, + user_ids: Optional[list[uuid.UUID]] = None, + *args, + **kwargs, + ): + return self.providers.vector_db.get_users_overview( + [str(ele) for ele in user_ids] if user_ids else None + ) + + @telemetry_event("Delete") + async def delete( + self, + keys: list[str], + values: list[Union[bool, int, str]], + *args, + **kwargs, + ): + metadata = ", ".join( + f"{key}={value}" for key, value in zip(keys, values) + ) + values = [str(value) for value in values] + logger.info(f"Deleting entries with metadata: {metadata}") + ids = self.providers.vector_db.delete_by_metadata(keys, values) + if not ids: + raise R2RException( + status_code=404, message="No entries found for deletion." + ) + for id in ids: + self.providers.vector_db.delete_from_documents_overview(id) + return f"Documents {ids} deleted successfully." 
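    # Usage sketch (hypothetical values): remove every chunk whose metadata
    # matches BOTH keys, then drop the corresponding overview rows:
    #
    #   await management_service.delete(
    #       keys=["document_id", "version"],
    #       values=["9fbe403b-c11c-5aae-8ade-ef22980c3ad1", "v0"],
    #   )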
+ + @telemetry_event("DocumentsOverview") + async def adocuments_overview( + self, + document_ids: Optional[list[uuid.UUID]] = None, + user_ids: Optional[list[uuid.UUID]] = None, + *args: Any, + **kwargs: Any, + ): + return self.providers.vector_db.get_documents_overview( + filter_document_ids=( + [str(ele) for ele in document_ids] if document_ids else None + ), + filter_user_ids=( + [str(ele) for ele in user_ids] if user_ids else None + ), + ) + + @telemetry_event("DocumentChunks") + async def document_chunks( + self, + document_id: uuid.UUID, + *args, + **kwargs, + ): + return self.providers.vector_db.get_document_chunks(str(document_id)) + + @telemetry_event("UsersOverview") + async def users_overview( + self, + user_ids: Optional[list[uuid.UUID]], + *args, + **kwargs, + ): + return self.providers.vector_db.get_users_overview( + [str(ele) for ele in user_ids] if user_ids else None + ) + + @telemetry_event("InspectKnowledgeGraph") + async def inspect_knowledge_graph( + self, limit=10000, *args: Any, **kwargs: Any + ): + if self.providers.kg is None: + raise R2RException( + status_code=404, message="Knowledge Graph provider not found." + ) + + rel_query = f""" + MATCH (n1)-[r]->(n2) + RETURN n1.id AS subject, type(r) AS relation, n2.id AS object + LIMIT {limit} + """ + + try: + with self.providers.kg.client.session( + database=self.providers.kg._database + ) as session: + results = session.run(rel_query) + relationships = [ + (record["subject"], record["relation"], record["object"]) + for record in results + ] + + # Create graph representation and group relationships + graph, grouped_relationships = self.process_relationships( + relationships + ) + + # Generate output + output = self.generate_output(grouped_relationships, graph) + + return "\n".join(output) + + except Exception as e: + logger.error(f"Error printing relationships: {str(e)}") + raise R2RException( + status_code=500, + message=f"An error occurred while fetching relationships: {str(e)}", + ) + + def process_relationships( + self, relationships: List[Tuple[str, str, str]] + ) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[str]]]]: + graph = defaultdict(list) + grouped = defaultdict(lambda: defaultdict(list)) + for subject, relation, obj in relationships: + graph[subject].append(obj) + grouped[subject][relation].append(obj) + if obj not in graph: + graph[obj] = [] + return dict(graph), dict(grouped) + + def generate_output( + self, + grouped_relationships: Dict[str, Dict[str, List[str]]], + graph: Dict[str, List[str]], + ) -> List[str]: + output = [] + + # Print grouped relationships + for subject, relations in grouped_relationships.items(): + output.append(f"\n== {subject} ==") + for relation, objects in relations.items(): + output.append(f" {relation}:") + for obj in objects: + output.append(f" - {obj}") + + # Print basic graph statistics + output.append("\n== Graph Statistics ==") + output.append(f"Number of nodes: {len(graph)}") + output.append( + f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}" + ) + output.append( + f"Number of connected components: {self.count_connected_components(graph)}" + ) + + # Find central nodes + central_nodes = self.get_central_nodes(graph) + output.append("\n== Most Central Nodes ==") + for node, centrality in central_nodes: + output.append(f" {node}: {centrality:.4f}") + + return output + + def count_connected_components(self, graph: Dict[str, List[str]]) -> int: + visited = set() + components = 0 + + def dfs(node): + visited.add(node) + for neighbor in graph[node]: + if neighbor
not in visited: + dfs(neighbor) + + for node in graph: + if node not in visited: + dfs(node) + components += 1 + + return components + + def get_central_nodes( + self, graph: Dict[str, List[str]] + ) -> List[Tuple[str, float]]: + degree = {node: len(neighbors) for node, neighbors in graph.items()} + total_nodes = len(graph) + # Guard against an empty or single-node graph, where degree centrality is undefined + denominator = max(total_nodes - 1, 1) + centrality = { + node: deg / denominator for node, deg in degree.items() + } + return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5] + + @telemetry_event("AppSettings") + async def app_settings( + self, + *args, + **kwargs, + ): + prompts = self.providers.prompt.get_all_prompts() + return { + "config": self.config.to_json(), + "prompts": { + name: prompt.dict() for name, prompt in prompts.items() + }, + } diff --git a/R2R/r2r/main/services/retrieval_service.py b/R2R/r2r/main/services/retrieval_service.py new file mode 100755 index 00000000..c4f6aff5 --- /dev/null +++ b/R2R/r2r/main/services/retrieval_service.py @@ -0,0 +1,207 @@ +import logging +import time +import uuid +from typing import Optional + +from r2r.base import ( + GenerationConfig, + KGSearchSettings, + KVLoggingSingleton, + R2RException, + RunManager, + VectorSearchSettings, + manage_run, + to_async_generator, +) +from r2r.pipes import EvalPipe +from r2r.telemetry.telemetry_decorator import telemetry_event + +from ..abstractions import R2RPipelines, R2RProviders +from ..assembly.config import R2RConfig +from .base import Service + +logger = logging.getLogger(__name__) + + +class RetrievalService(Service): + def __init__( + self, + config: R2RConfig, + providers: R2RProviders, + pipelines: R2RPipelines, + run_manager: RunManager, + logging_connection: KVLoggingSingleton, + ): + super().__init__( + config, providers, pipelines, run_manager, logging_connection + ) + + @telemetry_event("Search") + async def search( + self, + query: str, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + kg_search_settings: KGSearchSettings = KGSearchSettings(), + *args, + **kwargs, + ): + async with manage_run(self.run_manager, "search_app") as run_id: + t0 = time.time() + + if ( + kg_search_settings.use_kg_search + and self.config.kg.provider is None + ): + raise R2RException( + status_code=400, + message="Knowledge Graph search is not enabled in the configuration.", + ) + + if ( + vector_search_settings.use_vector_search + and self.config.vector_database.provider is None + ): + raise R2RException( + status_code=400, + message="Vector search is not enabled in the configuration.", + ) + + # TODO - Remove these transforms once we have a better way to handle this + for filter, value in vector_search_settings.search_filters.items(): + if isinstance(value, uuid.UUID): + vector_search_settings.search_filters[filter] = str(value) + + results = await self.pipelines.search_pipeline.run( + input=to_async_generator([query]), + vector_search_settings=vector_search_settings, + kg_search_settings=kg_search_settings, + run_manager=self.run_manager, + *args, + **kwargs, + ) + + t1 = time.time() + latency = f"{t1 - t0:.2f}" + + await self.logging_connection.log( + log_id=run_id, + key="search_latency", + value=latency, + is_info_log=False, + ) + + return results.dict() + + @telemetry_event("RAG") + async def rag( + self, + query: str, + rag_generation_config: GenerationConfig, + vector_search_settings: VectorSearchSettings = VectorSearchSettings(), + kg_search_settings: KGSearchSettings = KGSearchSettings(), + *args, + **kwargs, + ): + async with manage_run(self.run_manager, "rag_app") as
run_id: + try: + t0 = time.time() + + # TODO - Remove these transforms once we have a better way to handle this + for ( + filter, + value, + ) in vector_search_settings.search_filters.items(): + if isinstance(value, uuid.UUID): + vector_search_settings.search_filters[filter] = str( + value + ) + + if rag_generation_config.stream: + # NOTE: for streaming requests this records setup latency only, + # since generation happens lazily inside stream_response below. + t1 = time.time() + latency = f"{t1 - t0:.2f}" + + await self.logging_connection.log( + log_id=run_id, + key="rag_generation_latency", + value=latency, + is_info_log=False, + ) + + async def stream_response(): + async with manage_run(self.run_manager, "arag"): + async for ( + chunk + ) in await self.pipelines.streaming_rag_pipeline.run( + input=to_async_generator([query]), + run_manager=self.run_manager, + vector_search_settings=vector_search_settings, + kg_search_settings=kg_search_settings, + rag_generation_config=rag_generation_config, + ): + yield chunk + + return stream_response() + + results = await self.pipelines.rag_pipeline.run( + input=to_async_generator([query]), + run_manager=self.run_manager, + vector_search_settings=vector_search_settings, + kg_search_settings=kg_search_settings, + rag_generation_config=rag_generation_config, + *args, + **kwargs, + ) + + t1 = time.time() + latency = f"{t1 - t0:.2f}" + + await self.logging_connection.log( + log_id=run_id, + key="rag_generation_latency", + value=latency, + is_info_log=False, + ) + + if len(results) == 0: + raise R2RException( + status_code=404, message="No results found" + ) + if len(results) > 1: + logger.warning( + f"Multiple results found for query: {query}" + ) + # unpack the first result + return results[0] + + except R2RException: + # Propagate intentional HTTP-style errors (e.g. the 404 above) unchanged + raise + except Exception as e: + logger.error(f"Pipeline error: {str(e)}") + if "NoneType" in str(e): + raise R2RException( + status_code=502, + message="Ollama server not reachable or returned an invalid response", + ) + raise R2RException( + status_code=500, message="Internal Server Error" + ) + + @telemetry_event("Evaluate") + async def evaluate( + self, + query: str, + context: str, + completion: str, + eval_generation_config: Optional[GenerationConfig], + *args, + **kwargs, + ): + eval_payload = EvalPipe.EvalPayload( + query=query, + context=context, + completion=completion, + ) + result = await self.pipelines.eval_pipeline.run( + input=to_async_generator([eval_payload]), + run_manager=self.run_manager, + eval_generation_config=eval_generation_config, + ) + return result
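A minimal consumption sketch for the retrieval service above. The `retrieval_service` variable and the query are illustrative, and `GenerationConfig(stream=...)` is assumed to accept `stream` as a field, since `rag` branches on `rag_generation_config.stream`:

    import asyncio

    async def demo(retrieval_service):
        # Non-streaming: `rag` unpacks and returns the first pipeline result.
        result = await retrieval_service.rag(
            "What are the key themes of 1984?",
            rag_generation_config=GenerationConfig(stream=False),
        )
        print(result)

        # Streaming: `rag` instead returns an async generator of chunks.
        stream = await retrieval_service.rag(
            "What are the key themes of 1984?",
            rag_generation_config=GenerationConfig(stream=True),
        )
        async for chunk in stream:
            print(chunk, end="")

    # asyncio.run(demo(retrieval_service))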