Diffstat (limited to 'R2R/r2r/main/services')

 -rwxr-xr-x  R2R/r2r/main/services/__init__.py            |   0
 -rwxr-xr-x  R2R/r2r/main/services/base.py                |  22
 -rwxr-xr-x  R2R/r2r/main/services/ingestion_service.py   | 505
 -rwxr-xr-x  R2R/r2r/main/services/management_service.py  | 385
 -rwxr-xr-x  R2R/r2r/main/services/retrieval_service.py   | 207

5 files changed, 1119 insertions, 0 deletions
diff --git a/R2R/r2r/main/services/__init__.py b/R2R/r2r/main/services/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/main/services/__init__.py
diff --git a/R2R/r2r/main/services/base.py b/R2R/r2r/main/services/base.py
new file mode 100755
index 00000000..02c0675d
--- /dev/null
+++ b/R2R/r2r/main/services/base.py
@@ -0,0 +1,22 @@
+from abc import ABC
+
+from r2r.base import KVLoggingSingleton, RunManager
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+
+
+class Service(ABC):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        self.config = config
+        self.providers = providers
+        self.pipelines = pipelines
+        self.run_manager = run_manager
+        self.logging_connection = logging_connection
diff --git a/R2R/r2r/main/services/ingestion_service.py b/R2R/r2r/main/services/ingestion_service.py
new file mode 100755
index 00000000..5677807a
--- /dev/null
+++ b/R2R/r2r/main/services/ingestion_service.py
@@ -0,0 +1,505 @@
+import json
+import logging
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from typing import Any, Optional, Union
+
+from fastapi import Form, UploadFile
+
+from r2r.base import (
+    Document,
+    DocumentInfo,
+    DocumentType,
+    KVLoggingSingleton,
+    R2RDocumentProcessingError,
+    R2RException,
+    RunManager,
+    generate_id_from_label,
+    increment_version,
+    to_async_generator,
+)
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..api.requests import R2RIngestFilesRequest, R2RUpdateFilesRequest
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+MB_CONVERSION_FACTOR = 1024 * 1024
+
+
+class IngestionService(Service):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        super().__init__(
+            config, providers, pipelines, run_manager, logging_connection
+        )
+
+    def _file_to_document(
+        self, file: UploadFile, document_id: uuid.UUID, metadata: dict
+    ) -> Document:
+        file_extension = file.filename.split(".")[-1].lower()
+        if file_extension.upper() not in DocumentType.__members__:
+            raise R2RException(
+                status_code=415,
+                message=f"'{file_extension}' is not a valid DocumentType.",
+            )
+
+        document_title = (
+            metadata.get("title", None) or file.filename.split("/")[-1]
+        )
+        metadata["title"] = document_title
+
+        return Document(
+            id=document_id,
+            type=DocumentType[file_extension.upper()],
+            data=file.file.read(),
+            metadata=metadata,
+        )
+
+    @telemetry_event("IngestDocuments")
+    async def ingest_documents(
+        self,
+        documents: list[Document],
+        versions: Optional[list[str]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        if len(documents) == 0:
+            raise R2RException(
+                status_code=400, message="No documents provided for ingestion."
+            )
+
+        document_infos = []
+        skipped_documents = []
+        processed_documents = {}
+        duplicate_documents = defaultdict(list)
+
+        existing_document_info = {
+            doc_info.document_id: doc_info
+            for doc_info in self.providers.vector_db.get_documents_overview()
+        }
+
+        for iteration, document in enumerate(documents):
+            version = versions[iteration] if versions else "v0"
+
+            # Check for duplicates within the current batch
+            if document.id in processed_documents:
+                duplicate_documents[document.id].append(
+                    document.metadata.get("title", str(document.id))
+                )
+                continue
+
+            if (
+                document.id in existing_document_info
+                and existing_document_info[document.id].version == version
+                and existing_document_info[document.id].status == "success"
+            ):
+                logger.error(
+                    f"Document with ID {document.id} was already successfully processed."
+                )
+                if len(documents) == 1:
+                    raise R2RException(
+                        status_code=409,
+                        message=f"Document with ID {document.id} was already successfully processed.",
+                    )
+                skipped_documents.append(
+                    (
+                        document.id,
+                        document.metadata.get("title", None)
+                        or str(document.id),
+                    )
+                )
+                continue
+
+            now = datetime.now()
+            document_infos.append(
+                DocumentInfo(
+                    document_id=document.id,
+                    version=version,
+                    size_in_bytes=len(document.data),
+                    metadata=document.metadata.copy(),
+                    title=document.metadata.get("title", str(document.id)),
+                    user_id=document.metadata.get("user_id", None),
+                    created_at=now,
+                    updated_at=now,
+                    status="processing",  # Set initial status to `processing`
+                )
+            )
+
+            processed_documents[document.id] = document.metadata.get(
+                "title", str(document.id)
+            )
+
+        if duplicate_documents:
+            duplicate_details = [
+                f"{doc_id}: {', '.join(titles)}"
+                for doc_id, titles in duplicate_documents.items()
+            ]
+            warning_message = f"Duplicate documents detected: {'; '.join(duplicate_details)}. These duplicates were skipped."
+            raise R2RException(status_code=418, message=warning_message)
+
+        if skipped_documents and len(skipped_documents) == len(documents):
+            logger.error("All provided documents already exist.")
+            raise R2RException(
+                status_code=409,
+                message="All provided documents already exist. Use the `update_documents` endpoint instead to update these documents.",
+            )
+
+        # Insert pending document infos
+        self.providers.vector_db.upsert_documents_overview(document_infos)
+        ingestion_results = await self.pipelines.ingestion_pipeline.run(
+            input=to_async_generator(
+                [
+                    doc
+                    for doc in documents
+                    if doc.id
+                    not in [skipped[0] for skipped in skipped_documents]
+                ]
+            ),
+            versions=[info.version for info in document_infos],
+            run_manager=self.run_manager,
+            *args,
+            **kwargs,
+        )
+
+        return await self._process_ingestion_results(
+            ingestion_results,
+            document_infos,
+            skipped_documents,
+            processed_documents,
+        )
+
+    @telemetry_event("IngestFiles")
+    async def ingest_files(
+        self,
+        files: list[UploadFile],
+        metadatas: Optional[list[dict]] = None,
+        document_ids: Optional[list[uuid.UUID]] = None,
+        versions: Optional[list[str]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        if not files:
+            raise R2RException(
+                status_code=400, message="No files provided for ingestion."
+            )
+
+        try:
+            documents = []
+            for iteration, file in enumerate(files):
+                logger.info(f"Processing file: {file.filename}")
+                if (
+                    file.size
+                    > self.config.app.get("max_file_size_in_mb", 32)
+                    * MB_CONVERSION_FACTOR
+                ):
+                    raise R2RException(
+                        status_code=413,
+                        message=f"File size exceeds maximum allowed size: {file.filename}",
+                    )
+                if not file.filename:
+                    raise R2RException(
+                        status_code=400, message="File name not provided."
+                    )
+
+                document_metadata = metadatas[iteration] if metadatas else {}
+                document_id = (
+                    document_ids[iteration]
+                    if document_ids
+                    else generate_id_from_label(file.filename.split("/")[-1])
+                )
+
+                document = self._file_to_document(
+                    file, document_id, document_metadata
+                )
+                documents.append(document)
+
+            return await self.ingest_documents(
+                documents, versions, *args, **kwargs
+            )
+
+        finally:
+            for file in files:
+                file.file.close()
+
+    @telemetry_event("UpdateFiles")
+    async def update_files(
+        self,
+        files: list[UploadFile],
+        document_ids: list[uuid.UUID],
+        metadatas: Optional[list[dict]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        if not files:
+            raise R2RException(
+                status_code=400, message="No files provided for update."
+            )
+
+        try:
+            if len(document_ids) != len(files):
+                raise R2RException(
+                    status_code=400,
+                    message="Number of ids does not match number of files.",
+                )
+
+            documents_overview = await self._documents_overview(
+                document_ids=document_ids
+            )
+            if len(documents_overview) != len(files):
+                raise R2RException(
+                    status_code=404,
+                    message="One or more documents was not found.",
+                )
+
+            documents = []
+            new_versions = []
+
+            for it, (file, doc_id, doc_info) in enumerate(
+                zip(files, document_ids, documents_overview)
+            ):
+                if not doc_info:
+                    raise R2RException(
+                        status_code=404,
+                        message=f"Document with id {doc_id} not found.",
+                    )
+
+                new_version = increment_version(doc_info.version)
+                new_versions.append(new_version)
+
+                updated_metadata = (
+                    metadatas[it] if metadatas else doc_info.metadata
+                )
+                updated_metadata["title"] = (
+                    updated_metadata.get("title", None)
+                    or file.filename.split("/")[-1]
+                )
+
+                document = self._file_to_document(
+                    file, doc_id, updated_metadata
+                )
+                documents.append(document)
+
+            ingestion_results = await self.ingest_documents(
+                documents, versions=new_versions, *args, **kwargs
+            )
+
+            for doc_id, old_version in zip(
+                document_ids,
+                [doc_info.version for doc_info in documents_overview],
+            ):
+                await self._delete(
+                    ["document_id", "version"], [str(doc_id), old_version]
+                )
+                self.providers.vector_db.delete_from_documents_overview(
+                    doc_id, old_version
+                )
+
+            return ingestion_results
+
+        finally:
+            for file in files:
+                file.file.close()
+
+    async def _process_ingestion_results(
+        self,
+        ingestion_results: dict,
+        document_infos: list[DocumentInfo],
+        skipped_documents: list[tuple[str, str]],
+        processed_documents: dict,
+    ):
+        skipped_ids = [ele[0] for ele in skipped_documents]
+        failed_ids = []
+        successful_ids = []
+
+        results = {}
+        if ingestion_results["embedding_pipeline_output"]:
+            results = {
+                k: v for k, v in ingestion_results["embedding_pipeline_output"]
+            }
+        for doc_id, error in results.items():
+            if isinstance(error, R2RDocumentProcessingError):
+                logger.error(
+                    f"Error processing document with ID {error.document_id}: {error.message}"
+                )
+                failed_ids.append(error.document_id)
+            elif isinstance(error, Exception):
+                logger.error(f"Error processing document: {error}")
+                failed_ids.append(doc_id)
+            else:
+                successful_ids.append(doc_id)
+
+        documents_to_upsert = []
+        for document_info in document_infos:
+            if document_info.document_id not in skipped_ids:
+                if document_info.document_id in failed_ids:
+                    document_info.status = "failure"
+                elif document_info.document_id in successful_ids:
+                    document_info.status = "success"
+                documents_to_upsert.append(document_info)
+
+        if documents_to_upsert:
+            self.providers.vector_db.upsert_documents_overview(
+                documents_to_upsert
+            )
+
+        results = {
+            "processed_documents": [
+                f"Document '{processed_documents[document_id]}' processed successfully."
+                for document_id in successful_ids
+            ],
+            "failed_documents": [
+                f"Document '{processed_documents[document_id]}': {results[document_id]}"
+                for document_id in failed_ids
+            ],
+            "skipped_documents": [
+                f"Document '{filename}' skipped since it already exists."
+                for _, filename in skipped_documents
+            ],
+        }
+
+        # TODO - Clean up logging for document parse results
+        run_ids = list(self.run_manager.run_info.keys())
+        if run_ids:
+            run_id = run_ids[0]
+            for key in results:
+                if key in ["processed_documents", "failed_documents"]:
+                    for value in results[key]:
+                        await self.logging_connection.log(
+                            log_id=run_id,
+                            key="document_parse_result",
+                            value=value,
+                        )
+        return results
+
+    @staticmethod
+    def parse_ingest_files_form_data(
+        metadatas: Optional[str] = Form(None),
+        document_ids: str = Form(None),
+        versions: Optional[str] = Form(None),
+    ) -> R2RIngestFilesRequest:
+        try:
+            parsed_metadatas = (
+                json.loads(metadatas)
+                if metadatas and metadatas != "null"
+                else None
+            )
+            if parsed_metadatas is not None and not isinstance(
+                parsed_metadatas, list
+            ):
+                raise ValueError("metadatas must be a list of dictionaries")
+
+            parsed_document_ids = (
+                json.loads(document_ids)
+                if document_ids and document_ids != "null"
+                else None
+            )
+            if parsed_document_ids is not None:
+                parsed_document_ids = [
+                    uuid.UUID(doc_id) for doc_id in parsed_document_ids
+                ]
+
+            parsed_versions = (
+                json.loads(versions)
+                if versions and versions != "null"
+                else None
+            )
+
+            request_data = {
+                "metadatas": parsed_metadatas,
+                "document_ids": parsed_document_ids,
+                "versions": parsed_versions,
+            }
+            return R2RIngestFilesRequest(**request_data)
+        except json.JSONDecodeError as e:
+            raise R2RException(
+                status_code=400, message=f"Invalid JSON in form data: {e}"
+            )
+        except ValueError as e:
+            raise R2RException(status_code=400, message=str(e))
+        except Exception as e:
+            raise R2RException(
+                status_code=400, message=f"Error processing form data: {e}"
+            )
+
+    @staticmethod
+    def parse_update_files_form_data(
+        metadatas: Optional[str] = Form(None),
+        document_ids: str = Form(...),
+    ) -> R2RUpdateFilesRequest:
+        try:
+            parsed_metadatas = (
+                json.loads(metadatas)
+                if metadatas and metadatas != "null"
+                else None
+            )
+            if parsed_metadatas is not None and not isinstance(
+                parsed_metadatas, list
+            ):
+                raise ValueError("metadatas must be a list of dictionaries")
+
+            if not document_ids or document_ids == "null":
+                raise ValueError("document_ids is required and cannot be null")
+
+            parsed_document_ids = json.loads(document_ids)
+            if not isinstance(parsed_document_ids, list):
+                raise ValueError("document_ids must be a list")
+            parsed_document_ids = [
+                uuid.UUID(doc_id) for doc_id in parsed_document_ids
+            ]
+
+            request_data = {
+                "metadatas": parsed_metadatas,
+                "document_ids": parsed_document_ids,
+            }
+            return R2RUpdateFilesRequest(**request_data)
+        except json.JSONDecodeError as e:
+            raise R2RException(
+                status_code=400, message=f"Invalid JSON in form data: {e}"
+            )
+        except ValueError as e:
+            raise R2RException(status_code=400, message=str(e))
+        except Exception as e:
+            raise R2RException(
+                status_code=400, message=f"Error processing form data: {e}"
+            )
+
+    # TODO - Move to mgmt service for document info, delete, post orchestration buildout
+    async def _documents_overview(
+        self,
+        document_ids: Optional[list[uuid.UUID]] = None,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        return self.providers.vector_db.get_documents_overview(
+            filter_document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            filter_user_ids=(
+                [str(ele) for ele in user_ids] if user_ids else None
+            ),
+        )
+
+    async def _delete(
+        self, keys: list[str], values: list[Union[bool, int, str]]
+    ):
+        logger.info(
+            f"Deleting documents which match on these keys and values: ({keys}, {values})"
+        )
+
+        ids = self.providers.vector_db.delete_by_metadata(keys, values)
+        if not ids:
+            raise R2RException(
+                status_code=404, message="No entries found for deletion."
+            )
+        return "Entries deleted successfully."
diff --git a/R2R/r2r/main/services/management_service.py b/R2R/r2r/main/services/management_service.py
new file mode 100755
index 00000000..00f1f56e
--- /dev/null
+++ b/R2R/r2r/main/services/management_service.py
@@ -0,0 +1,385 @@
+import logging
+import uuid
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from r2r.base import (
+    AnalysisTypes,
+    FilterCriteria,
+    KVLoggingSingleton,
+    LogProcessor,
+    R2RException,
+    RunManager,
+)
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+
+
+class ManagementService(Service):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        super().__init__(
+            config, providers, pipelines, run_manager, logging_connection
+        )
+
+    @telemetry_event("UpdatePrompt")
+    async def update_prompt(
+        self,
+        name: str,
+        template: Optional[str] = None,
+        input_types: Optional[dict[str, str]] = {},
+        *args,
+        **kwargs,
+    ):
+        self.providers.prompt.update_prompt(name, template, input_types)
+        return f"Prompt '{name}' added successfully."
+
+    @telemetry_event("Logs")
+    async def alogs(
+        self,
+        log_type_filter: Optional[str] = None,
+        max_runs_requested: int = 100,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        if self.logging_connection is None:
+            raise R2RException(
+                status_code=404, message="Logging provider not found."
+            )
+        if (
+            self.config.app.get("max_logs_per_request", 100)
+            > max_runs_requested
+        ):
+            raise R2RException(
+                status_code=400,
+                message="Max runs requested exceeds the limit.",
+            )
+
+        run_info = await self.logging_connection.get_run_info(
+            limit=max_runs_requested,
+            log_type_filter=log_type_filter,
+        )
+        run_ids = [run.run_id for run in run_info]
+        if len(run_ids) == 0:
+            return []
+        logs = await self.logging_connection.get_logs(run_ids)
+        # Aggregate logs by run_id and include run_type
+        aggregated_logs = []
+
+        for run in run_info:
+            run_logs = [log for log in logs if log["log_id"] == run.run_id]
+            entries = [
+                {"key": log["key"], "value": log["value"]} for log in run_logs
+            ][
+                ::-1
+            ]  # Reverse order so that earliest logged values appear first.
+            aggregated_logs.append(
+                {
+                    "run_id": run.run_id,
+                    "run_type": run.log_type,
+                    "entries": entries,
+                }
+            )
+
+        return aggregated_logs
+
+    @telemetry_event("Analytics")
+    async def aanalytics(
+        self,
+        filter_criteria: FilterCriteria,
+        analysis_types: AnalysisTypes,
+        *args,
+        **kwargs,
+    ):
+        run_info = await self.logging_connection.get_run_info(limit=100)
+        run_ids = [info.run_id for info in run_info]
+
+        if not run_ids:
+            return {
+                "analytics_data": "No logs found.",
+                "filtered_logs": {},
+            }
+        logs = await self.logging_connection.get_logs(run_ids=run_ids)
+
+        filters = {}
+        if filter_criteria.filters:
+            for key, value in filter_criteria.filters.items():
+                filters[key] = lambda log, value=value: (
+                    any(
+                        entry.get("key") == value
+                        for entry in log.get("entries", [])
+                    )
+                    if "entries" in log
+                    else log.get("key") == value
+                )
+
+        log_processor = LogProcessor(filters)
+        for log in logs:
+            if "entries" in log and isinstance(log["entries"], list):
+                log_processor.process_log(log)
+            elif "key" in log:
+                log_processor.process_log(log)
+            else:
+                logger.warning(
+                    f"Skipping log due to missing or malformed 'entries': {log}"
+                )
+
+        filtered_logs = dict(log_processor.populations.items())
+        results = {"filtered_logs": filtered_logs}
+
+        if analysis_types and analysis_types.analysis_types:
+            for (
+                filter_key,
+                analysis_config,
+            ) in analysis_types.analysis_types.items():
+                if filter_key in filtered_logs:
+                    analysis_type = analysis_config[0]
+                    if analysis_type == "bar_chart":
+                        extract_key = analysis_config[1]
+                        results[filter_key] = (
+                            AnalysisTypes.generate_bar_chart_data(
+                                filtered_logs[filter_key], extract_key
+                            )
+                        )
+                    elif analysis_type == "basic_statistics":
+                        extract_key = analysis_config[1]
+                        results[filter_key] = (
+                            AnalysisTypes.calculate_basic_statistics(
+                                filtered_logs[filter_key], extract_key
+                            )
+                        )
+                    elif analysis_type == "percentile":
+                        extract_key = analysis_config[1]
+                        percentile = int(analysis_config[2])
+                        results[filter_key] = (
+                            AnalysisTypes.calculate_percentile(
+                                filtered_logs[filter_key],
+                                extract_key,
+                                percentile,
+                            )
+                        )
+                    else:
+                        logger.warning(
+                            f"Unknown analysis type for filter key '{filter_key}': {analysis_type}"
+                        )
+
+        return results
+
+    @telemetry_event("AppSettings")
+    async def aapp_settings(self, *args: Any, **kwargs: Any):
+        prompts = self.providers.prompt.get_all_prompts()
+        return {
+            "config": self.config.to_json(),
+            "prompts": {
+                name: prompt.dict() for name, prompt in prompts.items()
+            },
+        }
+
+    @telemetry_event("UsersOverview")
+    async def ausers_overview(
+        self,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_users_overview(
+            [str(ele) for ele in user_ids] if user_ids else None
+        )
+
+    @telemetry_event("Delete")
+    async def delete(
+        self,
+        keys: list[str],
+        values: list[Union[bool, int, str]],
+        *args,
+        **kwargs,
+    ):
+        metadata = ", ".join(
+            f"{key}={value}" for key, value in zip(keys, values)
+        )
+        values = [str(value) for value in values]
+        logger.info(f"Deleting entries with metadata: {metadata}")
+        ids = self.providers.vector_db.delete_by_metadata(keys, values)
+        if not ids:
+            raise R2RException(
+                status_code=404, message="No entries found for deletion."
+            )
+        for id in ids:
+            self.providers.vector_db.delete_from_documents_overview(id)
+        return f"Documents {ids} deleted successfully."
+
+    @telemetry_event("DocumentsOverview")
+    async def adocuments_overview(
+        self,
+        document_ids: Optional[list[uuid.UUID]] = None,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        return self.providers.vector_db.get_documents_overview(
+            filter_document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            filter_user_ids=(
+                [str(ele) for ele in user_ids] if user_ids else None
+            ),
+        )
+
+    @telemetry_event("DocumentChunks")
+    async def document_chunks(
+        self,
+        document_id: uuid.UUID,
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_document_chunks(str(document_id))
+
+    @telemetry_event("UsersOverview")
+    async def users_overview(
+        self,
+        user_ids: Optional[list[uuid.UUID]],
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_users_overview(
+            [str(ele) for ele in user_ids]
+        )
+
+    @telemetry_event("InspectKnowledgeGraph")
+    async def inspect_knowledge_graph(
+        self, limit=10000, *args: Any, **kwargs: Any
+    ):
+        if self.providers.kg is None:
+            raise R2RException(
+                status_code=404, message="Knowledge Graph provider not found."
+            )
+
+        rel_query = f"""
+        MATCH (n1)-[r]->(n2)
+        RETURN n1.id AS subject, type(r) AS relation, n2.id AS object
+        LIMIT {limit}
+        """
+
+        try:
+            with self.providers.kg.client.session(
+                database=self.providers.kg._database
+            ) as session:
+                results = session.run(rel_query)
+                relationships = [
+                    (record["subject"], record["relation"], record["object"])
+                    for record in results
+                ]
+
+            # Create graph representation and group relationships
+            graph, grouped_relationships = self.process_relationships(
+                relationships
+            )
+
+            # Generate output
+            output = self.generate_output(grouped_relationships, graph)
+
+            return "\n".join(output)
+
+        except Exception as e:
+            logger.error(f"Error printing relationships: {str(e)}")
+            raise R2RException(
+                status_code=500,
+                message=f"An error occurred while fetching relationships: {str(e)}",
+            )
+
+    def process_relationships(
+        self, relationships: List[Tuple[str, str, str]]
+    ) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[str]]]]:
+        graph = defaultdict(list)
+        grouped = defaultdict(lambda: defaultdict(list))
+        for subject, relation, obj in relationships:
+            graph[subject].append(obj)
+            grouped[subject][relation].append(obj)
+            if obj not in graph:
+                graph[obj] = []
+        return dict(graph), dict(grouped)
+
+    def generate_output(
+        self,
+        grouped_relationships: Dict[str, Dict[str, List[str]]],
+        graph: Dict[str, List[str]],
+    ) -> List[str]:
+        output = []
+
+        # Print grouped relationships
+        for subject, relations in grouped_relationships.items():
+            output.append(f"\n== {subject} ==")
+            for relation, objects in relations.items():
+                output.append(f"  {relation}:")
+                for obj in objects:
+                    output.append(f"    - {obj}")
+
+        # Print basic graph statistics
+        output.append("\n== Graph Statistics ==")
+        output.append(f"Number of nodes: {len(graph)}")
+        output.append(
+            f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}"
+        )
+        output.append(
+            f"Number of connected components: {self.count_connected_components(graph)}"
+        )
+
+        # Find central nodes
+        central_nodes = self.get_central_nodes(graph)
+        output.append("\n== Most Central Nodes ==")
+        for node, centrality in central_nodes:
+            output.append(f"  {node}: {centrality:.4f}")
+
+        return output
+
+    def count_connected_components(self, graph: Dict[str, List[str]]) -> int:
+        visited = set()
+        components = 0
+
+        def dfs(node):
+            visited.add(node)
+            for neighbor in graph[node]:
+                if neighbor not in visited:
+                    dfs(neighbor)
+
+        for node in graph:
+            if node not in visited:
+                dfs(node)
+                components += 1
+
+        return components
+
+    def get_central_nodes(
+        self, graph: Dict[str, List[str]]
+    ) -> List[Tuple[str, float]]:
+        degree = {node: len(neighbors) for node, neighbors in graph.items()}
+        total_nodes = len(graph)
+        centrality = {
+            node: deg / (total_nodes - 1) for node, deg in degree.items()
+        }
+        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
+
+    @telemetry_event("AppSettings")
+    async def app_settings(
+        self,
+        *args,
+        **kwargs,
+    ):
+        prompts = self.providers.prompt.get_all_prompts()
+        return {
+            "config": self.config.to_json(),
+            "prompts": {
+                name: prompt.dict() for name, prompt in prompts.items()
+            },
+        }
diff --git a/R2R/r2r/main/services/retrieval_service.py b/R2R/r2r/main/services/retrieval_service.py
new file mode 100755
index 00000000..c4f6aff5
--- /dev/null
+++ b/R2R/r2r/main/services/retrieval_service.py
@@ -0,0 +1,207 @@
+import logging
+import time
+import uuid
+from typing import Optional
+
+from r2r.base import (
+    GenerationConfig,
+    KGSearchSettings,
+    KVLoggingSingleton,
+    R2RException,
+    RunManager,
+    VectorSearchSettings,
+    manage_run,
+    to_async_generator,
+)
+from r2r.pipes import EvalPipe
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+
+
+class RetrievalService(Service):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        super().__init__(
+            config, providers, pipelines, run_manager, logging_connection
+        )
+
+    @telemetry_event("Search")
+    async def search(
+        self,
+        query: str,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        kg_search_settings: KGSearchSettings = KGSearchSettings(),
+        *args,
+        **kwargs,
+    ):
+        async with manage_run(self.run_manager, "search_app") as run_id:
+            t0 = time.time()
+
+            if (
+                kg_search_settings.use_kg_search
+                and self.config.kg.provider is None
+            ):
+                raise R2RException(
+                    status_code=400,
+                    message="Knowledge Graph search is not enabled in the configuration.",
+                )
+
+            if (
+                vector_search_settings.use_vector_search
+                and self.config.vector_database.provider is None
+            ):
+                raise R2RException(
+                    status_code=400,
+                    message="Vector search is not enabled in the configuration.",
+                )
+
+            # TODO - Remove these transforms once we have a better way to handle this
+            for filter, value in vector_search_settings.search_filters.items():
+                if isinstance(value, uuid.UUID):
+                    vector_search_settings.search_filters[filter] = str(value)
+
+            results = await self.pipelines.search_pipeline.run(
+                input=to_async_generator([query]),
+                vector_search_settings=vector_search_settings,
+                kg_search_settings=kg_search_settings,
+                run_manager=self.run_manager,
+                *args,
+                **kwargs,
+            )
+
+            t1 = time.time()
+            latency = f"{t1 - t0:.2f}"
+
+            await self.logging_connection.log(
+                log_id=run_id,
+                key="search_latency",
+                value=latency,
+                is_info_log=False,
+            )
+
+            return results.dict()
+
+    @telemetry_event("RAG")
+    async def rag(
+        self,
+        query: str,
+        rag_generation_config: GenerationConfig,
+        vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+        kg_search_settings: KGSearchSettings = KGSearchSettings(),
+        *args,
+        **kwargs,
+    ):
+        async with manage_run(self.run_manager, "rag_app") as run_id:
+            try:
+                t0 = time.time()
+
+                # TODO - Remove these transforms once we have a better way to handle this
+                for (
+                    filter,
+                    value,
+                ) in vector_search_settings.search_filters.items():
+                    if isinstance(value, uuid.UUID):
+                        vector_search_settings.search_filters[filter] = str(
+                            value
+                        )
+
+                if rag_generation_config.stream:
+                    t1 = time.time()
+                    latency = f"{t1 - t0:.2f}"
+
+                    await self.logging_connection.log(
+                        log_id=run_id,
+                        key="rag_generation_latency",
+                        value=latency,
+                        is_info_log=False,
+                    )
+
+                    async def stream_response():
+                        async with manage_run(self.run_manager, "arag"):
+                            async for (
+                                chunk
+                            ) in await self.pipelines.streaming_rag_pipeline.run(
+                                input=to_async_generator([query]),
+                                run_manager=self.run_manager,
+                                vector_search_settings=vector_search_settings,
+                                kg_search_settings=kg_search_settings,
+                                rag_generation_config=rag_generation_config,
+                            ):
+                                yield chunk
+
+                    return stream_response()
+
+                results = await self.pipelines.rag_pipeline.run(
+                    input=to_async_generator([query]),
+                    run_manager=self.run_manager,
+                    vector_search_settings=vector_search_settings,
+                    kg_search_settings=kg_search_settings,
+                    rag_generation_config=rag_generation_config,
+                    *args,
+                    **kwargs,
+                )
+
+                t1 = time.time()
+                latency = f"{t1 - t0:.2f}"
+
+                await self.logging_connection.log(
+                    log_id=run_id,
+                    key="rag_generation_latency",
+                    value=latency,
+                    is_info_log=False,
+                )
+
+                if len(results) == 0:
+                    raise R2RException(
+                        status_code=404, message="No results found"
+                    )
+                if len(results) > 1:
+                    logger.warning(
+                        f"Multiple results found for query: {query}"
+                    )
+                # unpack the first result
+                return results[0]
+
+            except Exception as e:
+                logger.error(f"Pipeline error: {str(e)}")
+                if "NoneType" in str(e):
+                    raise R2RException(
+                        status_code=502,
+                        message="Ollama server not reachable or returned an invalid response",
+                    )
+                raise R2RException(
+                    status_code=500, message="Internal Server Error"
+                )
+
+    @telemetry_event("Evaluate")
+    async def evaluate(
+        self,
+        query: str,
+        context: str,
+        completion: str,
+        eval_generation_config: Optional[GenerationConfig],
+        *args,
+        **kwargs,
+    ):
+        eval_payload = EvalPipe.EvalPayload(
+            query=query,
+            context=context,
+            completion=completion,
+        )
+        result = await self.eval_pipeline.run(
+            input=to_async_generator([eval_payload]),
+            run_manager=self.run_manager,
+            eval_generation_config=eval_generation_config,
+        )
+        return result
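
For orientation, the three concrete services added in this commit (ingestion, management, retrieval) all follow the pattern established in base.py: subclass Service and forward the same five dependencies to the base constructor, which stores them as attributes. Below is a minimal, hypothetical sketch of that pattern, not part of the diff itself; the class name EchoService and its describe method are invented for illustration, and the absolute import paths are assumed from the relative imports (..abstractions, ..assembly.config) used above.

# Hypothetical sketch of a service built on the Service base class from base.py.
# Assumes the r2r package is installed and that the relative imports in the diff
# resolve to r2r.main.* as written here.
from r2r.base import KVLoggingSingleton, RunManager

from r2r.main.abstractions import R2RPipelines, R2RProviders
from r2r.main.assembly.config import R2RConfig
from r2r.main.services.base import Service


class EchoService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipelines: R2RPipelines,
        run_manager: RunManager,
        logging_connection: KVLoggingSingleton,
    ):
        # Same wiring as IngestionService, ManagementService, and
        # RetrievalService: Service.__init__ stores the five dependencies.
        super().__init__(
            config, providers, pipelines, run_manager, logging_connection
        )

    async def describe(self) -> dict:
        # Shared state is reachable through the attributes set by the base
        # class: self.config, self.providers, self.pipelines,
        # self.run_manager, self.logging_connection.
        return {"config": self.config.to_json()}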