path: root/R2R/r2r/main/services
Diffstat (limited to 'R2R/r2r/main/services')
-rwxr-xr-x  R2R/r2r/main/services/__init__.py              0
-rwxr-xr-x  R2R/r2r/main/services/base.py                 22
-rwxr-xr-x  R2R/r2r/main/services/ingestion_service.py   505
-rwxr-xr-x  R2R/r2r/main/services/management_service.py  385
-rwxr-xr-x  R2R/r2r/main/services/retrieval_service.py   207
5 files changed, 1119 insertions, 0 deletions
diff --git a/R2R/r2r/main/services/__init__.py b/R2R/r2r/main/services/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/main/services/__init__.py
diff --git a/R2R/r2r/main/services/base.py b/R2R/r2r/main/services/base.py
new file mode 100755
index 00000000..02c0675d
--- /dev/null
+++ b/R2R/r2r/main/services/base.py
@@ -0,0 +1,22 @@
+from abc import ABC
+
+from r2r.base import KVLoggingSingleton, RunManager
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+
+
+class Service(ABC):
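+ # Shared base class: each concrete service receives the assembled config,
+ # providers, pipelines, run manager, and logging connection and stores them
+ # for its own methods to use.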
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ pipelines: R2RPipelines,
+ run_manager: RunManager,
+ logging_connection: KVLoggingSingleton,
+ ):
+ self.config = config
+ self.providers = providers
+ self.pipelines = pipelines
+ self.run_manager = run_manager
+ self.logging_connection = logging_connection
diff --git a/R2R/r2r/main/services/ingestion_service.py b/R2R/r2r/main/services/ingestion_service.py
new file mode 100755
index 00000000..5677807a
--- /dev/null
+++ b/R2R/r2r/main/services/ingestion_service.py
@@ -0,0 +1,505 @@
+import json
+import logging
+import uuid
+from collections import defaultdict
+from datetime import datetime
+from typing import Any, Optional, Union
+
+from fastapi import Form, UploadFile
+
+from r2r.base import (
+ Document,
+ DocumentInfo,
+ DocumentType,
+ KVLoggingSingleton,
+ R2RDocumentProcessingError,
+ R2RException,
+ RunManager,
+ generate_id_from_label,
+ increment_version,
+ to_async_generator,
+)
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..api.requests import R2RIngestFilesRequest, R2RUpdateFilesRequest
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
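+# Bytes per megabyte, used to enforce the configured maximum upload size.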
+MB_CONVERSION_FACTOR = 1024 * 1024
+
+
+class IngestionService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ pipelines: R2RPipelines,
+ run_manager: RunManager,
+ logging_connection: KVLoggingSingleton,
+ ):
+ super().__init__(
+ config, providers, pipelines, run_manager, logging_connection
+ )
+
+ def _file_to_document(
+ self, file: UploadFile, document_id: uuid.UUID, metadata: dict
+ ) -> Document:
+ file_extension = file.filename.split(".")[-1].lower()
+ if file_extension.upper() not in DocumentType.__members__:
+ raise R2RException(
+ status_code=415,
+ message=f"'{file_extension}' is not a valid DocumentType.",
+ )
+
+ document_title = (
+ metadata.get("title", None) or file.filename.split("/")[-1]
+ )
+ metadata["title"] = document_title
+
+ return Document(
+ id=document_id,
+ type=DocumentType[file_extension.upper()],
+ data=file.file.read(),
+ metadata=metadata,
+ )
+
+ @telemetry_event("IngestDocuments")
+ async def ingest_documents(
+ self,
+ documents: list[Document],
+ versions: Optional[list[str]] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ if len(documents) == 0:
+ raise R2RException(
+ status_code=400, message="No documents provided for ingestion."
+ )
+
+ document_infos = []
+ skipped_documents = []
+ processed_documents = {}
+ duplicate_documents = defaultdict(list)
+
+ existing_document_info = {
+ doc_info.document_id: doc_info
+ for doc_info in self.providers.vector_db.get_documents_overview()
+ }
+
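+ # Walk the batch once: reject in-batch duplicates, skip documents already
+ # ingested successfully at the requested version, and stage a `processing`
+ # DocumentInfo row for everything else.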
+ for iteration, document in enumerate(documents):
+ version = versions[iteration] if versions else "v0"
+
+ # Check for duplicates within the current batch
+ if document.id in processed_documents:
+ duplicate_documents[document.id].append(
+ document.metadata.get("title", str(document.id))
+ )
+ continue
+
+ if (
+ document.id in existing_document_info
+ and existing_document_info[document.id].version == version
+ and existing_document_info[document.id].status == "success"
+ ):
+ logger.error(
+ f"Document with ID {document.id} was already successfully processed."
+ )
+ if len(documents) == 1:
+ raise R2RException(
+ status_code=409,
+ message=f"Document with ID {document.id} was already successfully processed.",
+ )
+ skipped_documents.append(
+ (
+ document.id,
+ document.metadata.get("title", None)
+ or str(document.id),
+ )
+ )
+ continue
+
+ now = datetime.now()
+ document_infos.append(
+ DocumentInfo(
+ document_id=document.id,
+ version=version,
+ size_in_bytes=len(document.data),
+ metadata=document.metadata.copy(),
+ title=document.metadata.get("title", str(document.id)),
+ user_id=document.metadata.get("user_id", None),
+ created_at=now,
+ updated_at=now,
+ status="processing", # Set initial status to `processing`
+ )
+ )
+
+ processed_documents[document.id] = document.metadata.get(
+ "title", str(document.id)
+ )
+
+ if duplicate_documents:
+ duplicate_details = [
+ f"{doc_id}: {', '.join(titles)}"
+ for doc_id, titles in duplicate_documents.items()
+ ]
+ warning_message = f"Duplicate documents detected: {'; '.join(duplicate_details)}. These duplicates were skipped."
+ raise R2RException(status_code=418, message=warning_message)
+
+ if skipped_documents and len(skipped_documents) == len(documents):
+ logger.error("All provided documents already exist.")
+ raise R2RException(
+ status_code=409,
+ message="All provided documents already exist. Use the `update_documents` endpoint instead to update these documents.",
+ )
+
+ # Insert pending document infos
+ self.providers.vector_db.upsert_documents_overview(document_infos)
+ ingestion_results = await self.pipelines.ingestion_pipeline.run(
+ input=to_async_generator(
+ [
+ doc
+ for doc in documents
+ if doc.id
+ not in [skipped[0] for skipped in skipped_documents]
+ ]
+ ),
+ versions=[info.version for info in document_infos],
+ run_manager=self.run_manager,
+ *args,
+ **kwargs,
+ )
+
+ return await self._process_ingestion_results(
+ ingestion_results,
+ document_infos,
+ skipped_documents,
+ processed_documents,
+ )
+
+ @telemetry_event("IngestFiles")
+ async def ingest_files(
+ self,
+ files: list[UploadFile],
+ metadatas: Optional[list[dict]] = None,
+ document_ids: Optional[list[uuid.UUID]] = None,
+ versions: Optional[list[str]] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ if not files:
+ raise R2RException(
+ status_code=400, message="No files provided for ingestion."
+ )
+
+ try:
+ documents = []
+ for iteration, file in enumerate(files):
+ logger.info(f"Processing file: {file.filename}")
+ if (
+ file.size
+ > self.config.app.get("max_file_size_in_mb", 32)
+ * MB_CONVERSION_FACTOR
+ ):
+ raise R2RException(
+ status_code=413,
+ message=f"File size exceeds maximum allowed size: {file.filename}",
+ )
+ if not file.filename:
+ raise R2RException(
+ status_code=400, message="File name not provided."
+ )
+
+ document_metadata = metadatas[iteration] if metadatas else {}
+ document_id = (
+ document_ids[iteration]
+ if document_ids
+ else generate_id_from_label(file.filename.split("/")[-1])
+ )
+
+ document = self._file_to_document(
+ file, document_id, document_metadata
+ )
+ documents.append(document)
+
+ return await self.ingest_documents(
+ documents, versions, *args, **kwargs
+ )
+
+ finally:
+ for file in files:
+ file.file.close()
+
+ @telemetry_event("UpdateFiles")
+ async def update_files(
+ self,
+ files: list[UploadFile],
+ document_ids: list[uuid.UUID],
+ metadatas: Optional[list[dict]] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ if not files:
+ raise R2RException(
+ status_code=400, message="No files provided for update."
+ )
+
+ try:
+ if len(document_ids) != len(files):
+ raise R2RException(
+ status_code=400,
+ message="Number of ids does not match number of files.",
+ )
+
+ documents_overview = await self._documents_overview(
+ document_ids=document_ids
+ )
+ if len(documents_overview) != len(files):
+ raise R2RException(
+ status_code=404,
+ message="One or more documents was not found.",
+ )
+
+ documents = []
+ new_versions = []
+
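+            # For each file: bump the stored version, re-ingest under the new
+            # version, then drop the old version's chunks and overview entry.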
+ for it, (file, doc_id, doc_info) in enumerate(
+ zip(files, document_ids, documents_overview)
+ ):
+ if not doc_info:
+ raise R2RException(
+ status_code=404,
+ message=f"Document with id {doc_id} not found.",
+ )
+
+ new_version = increment_version(doc_info.version)
+ new_versions.append(new_version)
+
+ updated_metadata = (
+ metadatas[it] if metadatas else doc_info.metadata
+ )
+ updated_metadata["title"] = (
+ updated_metadata.get("title", None)
+ or file.filename.split("/")[-1]
+ )
+
+ document = self._file_to_document(
+ file, doc_id, updated_metadata
+ )
+ documents.append(document)
+
+ ingestion_results = await self.ingest_documents(
+ documents, versions=new_versions, *args, **kwargs
+ )
+
+ for doc_id, old_version in zip(
+ document_ids,
+ [doc_info.version for doc_info in documents_overview],
+ ):
+ await self._delete(
+ ["document_id", "version"], [str(doc_id), old_version]
+ )
+ self.providers.vector_db.delete_from_documents_overview(
+ doc_id, old_version
+ )
+
+ return ingestion_results
+
+ finally:
+ for file in files:
+ file.file.close()
+
+ async def _process_ingestion_results(
+ self,
+ ingestion_results: dict,
+ document_infos: list[DocumentInfo],
+ skipped_documents: list[tuple[str, str]],
+ processed_documents: dict,
+ ):
+ skipped_ids = [ele[0] for ele in skipped_documents]
+ failed_ids = []
+ successful_ids = []
+
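+        # The embedding pipeline reports one entry per document: an exception
+        # value marks a failure, anything else counts as success. Statuses are
+        # then written back to the documents overview.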
+ results = {}
+ if ingestion_results["embedding_pipeline_output"]:
+ results = {
+ k: v for k, v in ingestion_results["embedding_pipeline_output"]
+ }
+ for doc_id, error in results.items():
+ if isinstance(error, R2RDocumentProcessingError):
+ logger.error(
+ f"Error processing document with ID {error.document_id}: {error.message}"
+ )
+ failed_ids.append(error.document_id)
+ elif isinstance(error, Exception):
+ logger.error(f"Error processing document: {error}")
+ failed_ids.append(doc_id)
+ else:
+ successful_ids.append(doc_id)
+
+ documents_to_upsert = []
+ for document_info in document_infos:
+ if document_info.document_id not in skipped_ids:
+ if document_info.document_id in failed_ids:
+ document_info.status = "failure"
+ elif document_info.document_id in successful_ids:
+ document_info.status = "success"
+ documents_to_upsert.append(document_info)
+
+ if documents_to_upsert:
+ self.providers.vector_db.upsert_documents_overview(
+ documents_to_upsert
+ )
+
+ results = {
+ "processed_documents": [
+ f"Document '{processed_documents[document_id]}' processed successfully."
+ for document_id in successful_ids
+ ],
+ "failed_documents": [
+ f"Document '{processed_documents[document_id]}': {results[document_id]}"
+ for document_id in failed_ids
+ ],
+ "skipped_documents": [
+ f"Document '{filename}' skipped since it already exists."
+ for _, filename in skipped_documents
+ ],
+ }
+
+ # TODO - Clean up logging for document parse results
+ run_ids = list(self.run_manager.run_info.keys())
+ if run_ids:
+ run_id = run_ids[0]
+ for key in results:
+ if key in ["processed_documents", "failed_documents"]:
+ for value in results[key]:
+ await self.logging_connection.log(
+ log_id=run_id,
+ key="document_parse_result",
+ value=value,
+ )
+ return results
+
+ @staticmethod
+ def parse_ingest_files_form_data(
+ metadatas: Optional[str] = Form(None),
+ document_ids: Optional[str] = Form(None),
+ versions: Optional[str] = Form(None),
+ ) -> R2RIngestFilesRequest:
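+        # Multipart form fields arrive as JSON-encoded strings (or the literal
+        # "null"); decode and validate them before building the request model.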
+ try:
+ parsed_metadatas = (
+ json.loads(metadatas)
+ if metadatas and metadatas != "null"
+ else None
+ )
+ if parsed_metadatas is not None and not isinstance(
+ parsed_metadatas, list
+ ):
+ raise ValueError("metadatas must be a list of dictionaries")
+
+ parsed_document_ids = (
+ json.loads(document_ids)
+ if document_ids and document_ids != "null"
+ else None
+ )
+ if parsed_document_ids is not None:
+ parsed_document_ids = [
+ uuid.UUID(doc_id) for doc_id in parsed_document_ids
+ ]
+
+ parsed_versions = (
+ json.loads(versions)
+ if versions and versions != "null"
+ else None
+ )
+
+ request_data = {
+ "metadatas": parsed_metadatas,
+ "document_ids": parsed_document_ids,
+ "versions": parsed_versions,
+ }
+ return R2RIngestFilesRequest(**request_data)
+ except json.JSONDecodeError as e:
+ raise R2RException(
+ status_code=400, message=f"Invalid JSON in form data: {e}"
+ )
+ except ValueError as e:
+ raise R2RException(status_code=400, message=str(e))
+ except Exception as e:
+ raise R2RException(
+ status_code=400, message=f"Error processing form data: {e}"
+ )
+
+ @staticmethod
+ def parse_update_files_form_data(
+ metadatas: Optional[str] = Form(None),
+ document_ids: str = Form(...),
+ ) -> R2RUpdateFilesRequest:
+ try:
+ parsed_metadatas = (
+ json.loads(metadatas)
+ if metadatas and metadatas != "null"
+ else None
+ )
+ if parsed_metadatas is not None and not isinstance(
+ parsed_metadatas, list
+ ):
+ raise ValueError("metadatas must be a list of dictionaries")
+
+ if not document_ids or document_ids == "null":
+ raise ValueError("document_ids is required and cannot be null")
+
+ parsed_document_ids = json.loads(document_ids)
+ if not isinstance(parsed_document_ids, list):
+ raise ValueError("document_ids must be a list")
+ parsed_document_ids = [
+ uuid.UUID(doc_id) for doc_id in parsed_document_ids
+ ]
+
+ request_data = {
+ "metadatas": parsed_metadatas,
+ "document_ids": parsed_document_ids,
+ }
+ return R2RUpdateFilesRequest(**request_data)
+ except json.JSONDecodeError as e:
+ raise R2RException(
+ status_code=400, message=f"Invalid JSON in form data: {e}"
+ )
+ except ValueError as e:
+ raise R2RException(status_code=400, message=str(e))
+ except Exception as e:
+ raise R2RException(
+ status_code=400, message=f"Error processing form data: {e}"
+ )
+
+ # TODO - Move to mgmt service for document info, delete, post orchestration buildout
+ async def _documents_overview(
+ self,
+ document_ids: Optional[list[uuid.UUID]] = None,
+ user_ids: Optional[list[uuid.UUID]] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ return self.providers.vector_db.get_documents_overview(
+ filter_document_ids=(
+ [str(ele) for ele in document_ids] if document_ids else None
+ ),
+ filter_user_ids=(
+ [str(ele) for ele in user_ids] if user_ids else None
+ ),
+ )
+
+ async def _delete(
+ self, keys: list[str], values: list[Union[bool, int, str]]
+ ):
+ logger.info(
+ f"Deleting documents which match on these keys and values: ({keys}, {values})"
+ )
+
+ ids = self.providers.vector_db.delete_by_metadata(keys, values)
+ if not ids:
+ raise R2RException(
+ status_code=404, message="No entries found for deletion."
+ )
+ return "Entries deleted successfully."
diff --git a/R2R/r2r/main/services/management_service.py b/R2R/r2r/main/services/management_service.py
new file mode 100755
index 00000000..00f1f56e
--- /dev/null
+++ b/R2R/r2r/main/services/management_service.py
@@ -0,0 +1,385 @@
+import logging
+import uuid
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from r2r.base import (
+ AnalysisTypes,
+ FilterCriteria,
+ KVLoggingSingleton,
+ LogProcessor,
+ R2RException,
+ RunManager,
+)
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+
+
+class ManagementService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ pipelines: R2RPipelines,
+ run_manager: RunManager,
+ logging_connection: KVLoggingSingleton,
+ ):
+ super().__init__(
+ config, providers, pipelines, run_manager, logging_connection
+ )
+
+ @telemetry_event("UpdatePrompt")
+ async def update_prompt(
+ self,
+ name: str,
+ template: Optional[str] = None,
+ input_types: Optional[dict[str, str]] = {},
+ *args,
+ **kwargs,
+ ):
+ self.providers.prompt.update_prompt(name, template, input_types)
+ return f"Prompt '{name}' added successfully."
+
+ @telemetry_event("Logs")
+ async def alogs(
+ self,
+ log_type_filter: Optional[str] = None,
+ max_runs_requested: int = 100,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ if self.logging_connection is None:
+ raise R2RException(
+ status_code=404, message="Logging provider not found."
+ )
+ if (
+ max_runs_requested
+ > self.config.app.get("max_logs_per_request", 100)
+ ):
+ raise R2RException(
+ status_code=400,
+ message="Max runs requested exceeds the limit.",
+ )
+
+ run_info = await self.logging_connection.get_run_info(
+ limit=max_runs_requested,
+ log_type_filter=log_type_filter,
+ )
+ run_ids = [run.run_id for run in run_info]
+ if len(run_ids) == 0:
+ return []
+ logs = await self.logging_connection.get_logs(run_ids)
+ # Aggregate logs by run_id and include run_type
+ aggregated_logs = []
+
+ for run in run_info:
+ run_logs = [log for log in logs if log["log_id"] == run.run_id]
+ entries = [
+ {"key": log["key"], "value": log["value"]} for log in run_logs
+ ][
+ ::-1
+ ] # Reverse order so that earliest logged values appear first.
+ aggregated_logs.append(
+ {
+ "run_id": run.run_id,
+ "run_type": run.log_type,
+ "entries": entries,
+ }
+ )
+
+ return aggregated_logs
+
+ @telemetry_event("Analytics")
+ async def aanalytics(
+ self,
+ filter_criteria: FilterCriteria,
+ analysis_types: AnalysisTypes,
+ *args,
+ **kwargs,
+ ):
+ run_info = await self.logging_connection.get_run_info(limit=100)
+ run_ids = [info.run_id for info in run_info]
+
+ if not run_ids:
+ return {
+ "analytics_data": "No logs found.",
+ "filtered_logs": {},
+ }
+ logs = await self.logging_connection.get_logs(run_ids=run_ids)
+
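+        # Build one predicate per filter: for aggregated run logs it matches
+        # any entry whose key equals the filter value; for raw records it
+        # compares the record's own key.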
+ filters = {}
+ if filter_criteria.filters:
+ for key, value in filter_criteria.filters.items():
+ filters[key] = lambda log, value=value: (
+ any(
+ entry.get("key") == value
+ for entry in log.get("entries", [])
+ )
+ if "entries" in log
+ else log.get("key") == value
+ )
+
+ log_processor = LogProcessor(filters)
+ for log in logs:
+ if "entries" in log and isinstance(log["entries"], list):
+ log_processor.process_log(log)
+ elif "key" in log:
+ log_processor.process_log(log)
+ else:
+ logger.warning(
+ f"Skipping log due to missing or malformed 'entries': {log}"
+ )
+
+ filtered_logs = dict(log_processor.populations.items())
+ results = {"filtered_logs": filtered_logs}
+
+ if analysis_types and analysis_types.analysis_types:
+ for (
+ filter_key,
+ analysis_config,
+ ) in analysis_types.analysis_types.items():
+ if filter_key in filtered_logs:
+ analysis_type = analysis_config[0]
+ if analysis_type == "bar_chart":
+ extract_key = analysis_config[1]
+ results[filter_key] = (
+ AnalysisTypes.generate_bar_chart_data(
+ filtered_logs[filter_key], extract_key
+ )
+ )
+ elif analysis_type == "basic_statistics":
+ extract_key = analysis_config[1]
+ results[filter_key] = (
+ AnalysisTypes.calculate_basic_statistics(
+ filtered_logs[filter_key], extract_key
+ )
+ )
+ elif analysis_type == "percentile":
+ extract_key = analysis_config[1]
+ percentile = int(analysis_config[2])
+ results[filter_key] = (
+ AnalysisTypes.calculate_percentile(
+ filtered_logs[filter_key],
+ extract_key,
+ percentile,
+ )
+ )
+ else:
+ logger.warning(
+ f"Unknown analysis type for filter key '{filter_key}': {analysis_type}"
+ )
+
+ return results
+
+ @telemetry_event("AppSettings")
+ async def aapp_settings(self, *args: Any, **kwargs: Any):
+ prompts = self.providers.prompt.get_all_prompts()
+ return {
+ "config": self.config.to_json(),
+ "prompts": {
+ name: prompt.dict() for name, prompt in prompts.items()
+ },
+ }
+
+ @telemetry_event("UsersOverview")
+ async def ausers_overview(
+ self,
+ user_ids: Optional[list[uuid.UUID]] = None,
+ *args,
+ **kwargs,
+ ):
+ return self.providers.vector_db.get_users_overview(
+ [str(ele) for ele in user_ids] if user_ids else None
+ )
+
+ @telemetry_event("Delete")
+ async def delete(
+ self,
+ keys: list[str],
+ values: list[Union[bool, int, str]],
+ *args,
+ **kwargs,
+ ):
+ metadata = ", ".join(
+ f"{key}={value}" for key, value in zip(keys, values)
+ )
+ values = [str(value) for value in values]
+ logger.info(f"Deleting entries with metadata: {metadata}")
+ ids = self.providers.vector_db.delete_by_metadata(keys, values)
+ if not ids:
+ raise R2RException(
+ status_code=404, message="No entries found for deletion."
+ )
+ for id in ids:
+ self.providers.vector_db.delete_from_documents_overview(id)
+ return f"Documents {ids} deleted successfully."
+
+ @telemetry_event("DocumentsOverview")
+ async def adocuments_overview(
+ self,
+ document_ids: Optional[list[uuid.UUID]] = None,
+ user_ids: Optional[list[uuid.UUID]] = None,
+ *args: Any,
+ **kwargs: Any,
+ ):
+ return self.providers.vector_db.get_documents_overview(
+ filter_document_ids=(
+ [str(ele) for ele in document_ids] if document_ids else None
+ ),
+ filter_user_ids=(
+ [str(ele) for ele in user_ids] if user_ids else None
+ ),
+ )
+
+ @telemetry_event("DocumentChunks")
+ async def document_chunks(
+ self,
+ document_id: uuid.UUID,
+ *args,
+ **kwargs,
+ ):
+ return self.providers.vector_db.get_document_chunks(str(document_id))
+
+ @telemetry_event("UsersOverview")
+ async def users_overview(
+ self,
+ user_ids: Optional[list[uuid.UUID]],
+ *args,
+ **kwargs,
+ ):
+ return self.providers.vector_db.get_users_overview(
+ [str(ele) for ele in user_ids] if user_ids else None
+ )
+
+ @telemetry_event("InspectKnowledgeGraph")
+ async def inspect_knowledge_graph(
+ self, limit=10000, *args: Any, **kwargs: Any
+ ):
+ if self.providers.kg is None:
+ raise R2RException(
+ status_code=404, message="Knowledge Graph provider not found."
+ )
+
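+        # Cypher query: pull up to `limit` (subject, relation, object) triples
+        # from the knowledge graph for inspection.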
+ rel_query = f"""
+ MATCH (n1)-[r]->(n2)
+ RETURN n1.id AS subject, type(r) AS relation, n2.id AS object
+ LIMIT {limit}
+ """
+
+ try:
+ with self.providers.kg.client.session(
+ database=self.providers.kg._database
+ ) as session:
+ results = session.run(rel_query)
+ relationships = [
+ (record["subject"], record["relation"], record["object"])
+ for record in results
+ ]
+
+ # Create graph representation and group relationships
+ graph, grouped_relationships = self.process_relationships(
+ relationships
+ )
+
+ # Generate output
+ output = self.generate_output(grouped_relationships, graph)
+
+ return "\n".join(output)
+
+ except Exception as e:
+ logger.error(f"Error printing relationships: {str(e)}")
+ raise R2RException(
+ status_code=500,
+ message=f"An error occurred while fetching relationships: {str(e)}",
+ )
+
+ def process_relationships(
+ self, relationships: List[Tuple[str, str, str]]
+ ) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[str]]]]:
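+        # Build a plain adjacency list plus a subject -> relation -> objects
+        # grouping; objects with no outgoing edges still get an empty entry.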
+ graph = defaultdict(list)
+ grouped = defaultdict(lambda: defaultdict(list))
+ for subject, relation, obj in relationships:
+ graph[subject].append(obj)
+ grouped[subject][relation].append(obj)
+ if obj not in graph:
+ graph[obj] = []
+ return dict(graph), dict(grouped)
+
+ def generate_output(
+ self,
+ grouped_relationships: Dict[str, Dict[str, List[str]]],
+ graph: Dict[str, List[str]],
+ ) -> List[str]:
+ output = []
+
+ # Print grouped relationships
+ for subject, relations in grouped_relationships.items():
+ output.append(f"\n== {subject} ==")
+ for relation, objects in relations.items():
+ output.append(f" {relation}:")
+ for obj in objects:
+ output.append(f" - {obj}")
+
+ # Print basic graph statistics
+ output.append("\n== Graph Statistics ==")
+ output.append(f"Number of nodes: {len(graph)}")
+ output.append(
+ f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}"
+ )
+ output.append(
+ f"Number of connected components: {self.count_connected_components(graph)}"
+ )
+
+ # Find central nodes
+ central_nodes = self.get_central_nodes(graph)
+ output.append("\n== Most Central Nodes ==")
+ for node, centrality in central_nodes:
+ output.append(f" {node}: {centrality:.4f}")
+
+ return output
+
+ def count_connected_components(self, graph: Dict[str, List[str]]) -> int:
+ visited = set()
+ components = 0
+
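+        # Depth-first search over the stored adjacency lists; only outgoing
+        # edges are followed, so the count reflects reachability along the
+        # directed edges rather than true weakly-connected components.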
+ def dfs(node):
+ visited.add(node)
+ for neighbor in graph[node]:
+ if neighbor not in visited:
+ dfs(neighbor)
+
+ for node in graph:
+ if node not in visited:
+ dfs(node)
+ components += 1
+
+ return components
+
+ def get_central_nodes(
+ self, graph: Dict[str, List[str]]
+ ) -> List[Tuple[str, float]]:
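+        # Degree centrality: out-degree normalized by (total_nodes - 1),
+        # returning the five highest-scoring nodes.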
+ degree = {node: len(neighbors) for node, neighbors in graph.items()}
+ total_nodes = len(graph)
+ centrality = {
+ node: deg / (total_nodes - 1) for node, deg in degree.items()
+ }
+ return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
+
+ @telemetry_event("AppSettings")
+ async def app_settings(
+ self,
+ *args,
+ **kwargs,
+ ):
+ prompts = self.providers.prompt.get_all_prompts()
+ return {
+ "config": self.config.to_json(),
+ "prompts": {
+ name: prompt.dict() for name, prompt in prompts.items()
+ },
+ }
diff --git a/R2R/r2r/main/services/retrieval_service.py b/R2R/r2r/main/services/retrieval_service.py
new file mode 100755
index 00000000..c4f6aff5
--- /dev/null
+++ b/R2R/r2r/main/services/retrieval_service.py
@@ -0,0 +1,207 @@
+import logging
+import time
+import uuid
+from typing import Optional
+
+from r2r.base import (
+ GenerationConfig,
+ KGSearchSettings,
+ KVLoggingSingleton,
+ R2RException,
+ RunManager,
+ VectorSearchSettings,
+ manage_run,
+ to_async_generator,
+)
+from r2r.pipes import EvalPipe
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+
+
+class RetrievalService(Service):
+ def __init__(
+ self,
+ config: R2RConfig,
+ providers: R2RProviders,
+ pipelines: R2RPipelines,
+ run_manager: RunManager,
+ logging_connection: KVLoggingSingleton,
+ ):
+ super().__init__(
+ config, providers, pipelines, run_manager, logging_connection
+ )
+
+ @telemetry_event("Search")
+ async def search(
+ self,
+ query: str,
+ vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+ kg_search_settings: KGSearchSettings = KGSearchSettings(),
+ *args,
+ **kwargs,
+ ):
+ async with manage_run(self.run_manager, "search_app") as run_id:
+ t0 = time.time()
+
+ if (
+ kg_search_settings.use_kg_search
+ and self.config.kg.provider is None
+ ):
+ raise R2RException(
+ status_code=400,
+ message="Knowledge Graph search is not enabled in the configuration.",
+ )
+
+ if (
+ vector_search_settings.use_vector_search
+ and self.config.vector_database.provider is None
+ ):
+ raise R2RException(
+ status_code=400,
+ message="Vector search is not enabled in the configuration.",
+ )
+
+ # TODO - Remove these transforms once we have a better way to handle this
+ for filter, value in vector_search_settings.search_filters.items():
+ if isinstance(value, uuid.UUID):
+ vector_search_settings.search_filters[filter] = str(value)
+
+ results = await self.pipelines.search_pipeline.run(
+ input=to_async_generator([query]),
+ vector_search_settings=vector_search_settings,
+ kg_search_settings=kg_search_settings,
+ run_manager=self.run_manager,
+ *args,
+ **kwargs,
+ )
+
+ t1 = time.time()
+ latency = f"{t1 - t0:.2f}"
+
+ await self.logging_connection.log(
+ log_id=run_id,
+ key="search_latency",
+ value=latency,
+ is_info_log=False,
+ )
+
+ return results.dict()
+
+ @telemetry_event("RAG")
+ async def rag(
+ self,
+ query: str,
+ rag_generation_config: GenerationConfig,
+ vector_search_settings: VectorSearchSettings = VectorSearchSettings(),
+ kg_search_settings: KGSearchSettings = KGSearchSettings(),
+ *args,
+ **kwargs,
+ ):
+ async with manage_run(self.run_manager, "rag_app") as run_id:
+ try:
+ t0 = time.time()
+
+ # TODO - Remove these transforms once we have a better way to handle this
+ for (
+ filter,
+ value,
+ ) in vector_search_settings.search_filters.items():
+ if isinstance(value, uuid.UUID):
+ vector_search_settings.search_filters[filter] = str(
+ value
+ )
+
+ if rag_generation_config.stream:
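+                    # Streaming requests log latency up front and return an
+                    # async generator; the caller iterates it to receive
+                    # completion chunks as they are produced.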
+ t1 = time.time()
+ latency = f"{t1 - t0:.2f}"
+
+ await self.logging_connection.log(
+ log_id=run_id,
+ key="rag_generation_latency",
+ value=latency,
+ is_info_log=False,
+ )
+
+ async def stream_response():
+ async with manage_run(self.run_manager, "arag"):
+ async for (
+ chunk
+ ) in await self.pipelines.streaming_rag_pipeline.run(
+ input=to_async_generator([query]),
+ run_manager=self.run_manager,
+ vector_search_settings=vector_search_settings,
+ kg_search_settings=kg_search_settings,
+ rag_generation_config=rag_generation_config,
+ ):
+ yield chunk
+
+ return stream_response()
+
+ results = await self.pipelines.rag_pipeline.run(
+ input=to_async_generator([query]),
+ run_manager=self.run_manager,
+ vector_search_settings=vector_search_settings,
+ kg_search_settings=kg_search_settings,
+ rag_generation_config=rag_generation_config,
+ *args,
+ **kwargs,
+ )
+
+ t1 = time.time()
+ latency = f"{t1 - t0:.2f}"
+
+ await self.logging_connection.log(
+ log_id=run_id,
+ key="rag_generation_latency",
+ value=latency,
+ is_info_log=False,
+ )
+
+ if len(results) == 0:
+ raise R2RException(
+ status_code=404, message="No results found"
+ )
+ if len(results) > 1:
+ logger.warning(
+ f"Multiple results found for query: {query}"
+ )
+ # unpack the first result
+ return results[0]
+
+ except Exception as e:
+ logger.error(f"Pipeline error: {str(e)}")
+ if "NoneType" in str(e):
+ raise R2RException(
+ status_code=502,
+ message="Ollama server not reachable or returned an invalid response",
+ )
+ raise R2RException(
+ status_code=500, message="Internal Server Error"
+ )
+
+ @telemetry_event("Evaluate")
+ async def evaluate(
+ self,
+ query: str,
+ context: str,
+ completion: str,
+ eval_generation_config: Optional[GenerationConfig],
+ *args,
+ **kwargs,
+ ):
+ eval_payload = EvalPipe.EvalPayload(
+ query=query,
+ context=context,
+ completion=completion,
+ )
+ result = await self.pipelines.eval_pipeline.run(
+ input=to_async_generator([eval_payload]),
+ run_manager=self.run_manager,
+ eval_generation_config=eval_generation_config,
+ )
+ return result