path: root/R2R/r2r/main/services/management_service.py
author     S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer  S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit     4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree       ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/services/management_service.py
parent     cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download   gn-ai-master.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/main/services/management_service.py')
-rwxr-xr-x  R2R/r2r/main/services/management_service.py  385
1 files changed, 385 insertions, 0 deletions
diff --git a/R2R/r2r/main/services/management_service.py b/R2R/r2r/main/services/management_service.py
new file mode 100755
index 00000000..00f1f56e
--- /dev/null
+++ b/R2R/r2r/main/services/management_service.py
@@ -0,0 +1,385 @@
+import logging
+import uuid
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from r2r.base import (
+    AnalysisTypes,
+    FilterCriteria,
+    KVLoggingSingleton,
+    LogProcessor,
+    R2RException,
+    RunManager,
+)
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+
+
+class ManagementService(Service):
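+    """Management operations for an R2R app: prompt updates, run-log
+    retrieval, analytics, document/user overviews, deletion by metadata,
+    and knowledge-graph inspection."""
+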
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        super().__init__(
+            config, providers, pipelines, run_manager, logging_connection
+        )
+
+    @telemetry_event("UpdatePrompt")
+    async def update_prompt(
+        self,
+        name: str,
+        template: Optional[str] = None,
+        input_types: Optional[dict[str, str]] = None,
+        *args,
+        **kwargs,
+    ):
+        self.providers.prompt.update_prompt(
+            name, template, input_types or {}
+        )
+        return f"Prompt '{name}' updated successfully."
+
+    @telemetry_event("Logs")
+    async def alogs(
+        self,
+        log_type_filter: Optional[str] = None,
+        max_runs_requested: int = 100,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        if self.logging_connection is None:
+            raise R2RException(
+                status_code=404, message="Logging provider not found."
+            )
+        if (
+            max_runs_requested
+            > self.config.app.get("max_logs_per_request", 100)
+        ):
+            raise R2RException(
+                status_code=400,
+                message="Max runs requested exceeds the limit.",
+            )
+
+        run_info = await self.logging_connection.get_run_info(
+            limit=max_runs_requested,
+            log_type_filter=log_type_filter,
+        )
+        run_ids = [run.run_id for run in run_info]
+        if len(run_ids) == 0:
+            return []
+        logs = await self.logging_connection.get_logs(run_ids)
+        # Aggregate logs by run_id and include run_type
+        aggregated_logs = []
+
+        for run in run_info:
+            run_logs = [log for log in logs if log["log_id"] == run.run_id]
+            entries = [
+                {"key": log["key"], "value": log["value"]} for log in run_logs
+            ][::-1]  # Reverse order so that earliest logged values appear first.
+            aggregated_logs.append(
+                {
+                    "run_id": run.run_id,
+                    "run_type": run.log_type,
+                    "entries": entries,
+                }
+            )
+
+        return aggregated_logs
+
+    @telemetry_event("Analytics")
+    async def aanalytics(
+        self,
+        filter_criteria: FilterCriteria,
+        analysis_types: AnalysisTypes,
+        *args,
+        **kwargs,
+    ):
+        run_info = await self.logging_connection.get_run_info(limit=100)
+        run_ids = [info.run_id for info in run_info]
+
+        if not run_ids:
+            return {
+                "analytics_data": "No logs found.",
+                "filtered_logs": {},
+            }
+        logs = await self.logging_connection.get_logs(run_ids=run_ids)
+
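+        # Build one predicate per configured filter: a log matches when any of
+        # its entries has a "key" equal to the filter value; flat logs without
+        # an "entries" list compare their top-level "key" instead.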
+        filters = {}
+        if filter_criteria.filters:
+            for key, value in filter_criteria.filters.items():
+                filters[key] = lambda log, value=value: (
+                    any(
+                        entry.get("key") == value
+                        for entry in log.get("entries", [])
+                    )
+                    if "entries" in log
+                    else log.get("key") == value
+                )
+
+        log_processor = LogProcessor(filters)
+        for log in logs:
+            if "entries" in log and isinstance(log["entries"], list):
+                log_processor.process_log(log)
+            elif "key" in log:
+                log_processor.process_log(log)
+            else:
+                logger.warning(
+                    f"Skipping log due to missing or malformed 'entries': {log}"
+                )
+
+        filtered_logs = dict(log_processor.populations.items())
+        results = {"filtered_logs": filtered_logs}
+
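+        # Each analysis config is a sequence of the form
+        # (analysis_type, extract_key[, percentile]); dispatch the filtered
+        # population to the matching AnalysisTypes helper.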
+        if analysis_types and analysis_types.analysis_types:
+            for (
+                filter_key,
+                analysis_config,
+            ) in analysis_types.analysis_types.items():
+                if filter_key in filtered_logs:
+                    analysis_type = analysis_config[0]
+                    if analysis_type == "bar_chart":
+                        extract_key = analysis_config[1]
+                        results[filter_key] = (
+                            AnalysisTypes.generate_bar_chart_data(
+                                filtered_logs[filter_key], extract_key
+                            )
+                        )
+                    elif analysis_type == "basic_statistics":
+                        extract_key = analysis_config[1]
+                        results[filter_key] = (
+                            AnalysisTypes.calculate_basic_statistics(
+                                filtered_logs[filter_key], extract_key
+                            )
+                        )
+                    elif analysis_type == "percentile":
+                        extract_key = analysis_config[1]
+                        percentile = int(analysis_config[2])
+                        results[filter_key] = (
+                            AnalysisTypes.calculate_percentile(
+                                filtered_logs[filter_key],
+                                extract_key,
+                                percentile,
+                            )
+                        )
+                    else:
+                        logger.warning(
+                            f"Unknown analysis type for filter key '{filter_key}': {analysis_type}"
+                        )
+
+        return results
+
+    @telemetry_event("AppSettings")
+    async def aapp_settings(self, *args: Any, **kwargs: Any):
+        prompts = self.providers.prompt.get_all_prompts()
+        return {
+            "config": self.config.to_json(),
+            "prompts": {
+                name: prompt.dict() for name, prompt in prompts.items()
+            },
+        }
+
+    @telemetry_event("UsersOverview")
+    async def ausers_overview(
+        self,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_users_overview(
+            [str(ele) for ele in user_ids] if user_ids else None
+        )
+
+    @telemetry_event("Delete")
+    async def delete(
+        self,
+        keys: list[str],
+        values: list[Union[bool, int, str]],
+        *args,
+        **kwargs,
+    ):
+        metadata = ", ".join(
+            f"{key}={value}" for key, value in zip(keys, values)
+        )
+        values = [str(value) for value in values]
+        logger.info(f"Deleting entries with metadata: {metadata}")
+        ids = self.providers.vector_db.delete_by_metadata(keys, values)
+        if not ids:
+            raise R2RException(
+                status_code=404, message="No entries found for deletion."
+            )
+        for document_id in ids:
+            self.providers.vector_db.delete_from_documents_overview(document_id)
+        return f"Documents {ids} deleted successfully."
+
+    @telemetry_event("DocumentsOverview")
+    async def adocuments_overview(
+        self,
+        document_ids: Optional[list[uuid.UUID]] = None,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        return self.providers.vector_db.get_documents_overview(
+            filter_document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            filter_user_ids=(
+                [str(ele) for ele in user_ids] if user_ids else None
+            ),
+        )
+
+    @telemetry_event("DocumentChunks")
+    async def document_chunks(
+        self,
+        document_id: uuid.UUID,
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_document_chunks(str(document_id))
+
+    @telemetry_event("UsersOverview")
+    async def users_overview(
+        self,
+        user_ids: Optional[list[uuid.UUID]],
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_users_overview(
+            [str(ele) for ele in user_ids] if user_ids else None
+        )
+
+    @telemetry_event("InspectKnowledgeGraph")
+    async def inspect_knowledge_graph(
+        self, limit=10000, *args: Any, **kwargs: Any
+    ):
+        if self.providers.kg is None:
+            raise R2RException(
+                status_code=404, message="Knowledge Graph provider not found."
+            )
+
+        rel_query = f"""
+        MATCH (n1)-[r]->(n2)
+        RETURN n1.id AS subject, type(r) AS relation, n2.id AS object
+        LIMIT {limit}
+        """
+
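+        # Execute the Cypher query and materialize the results as
+        # (subject, relation, object) triples before the session closes.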
+        try:
+            with self.providers.kg.client.session(
+                database=self.providers.kg._database
+            ) as session:
+                results = session.run(rel_query)
+                relationships = [
+                    (record["subject"], record["relation"], record["object"])
+                    for record in results
+                ]
+
+            # Create graph representation and group relationships
+            graph, grouped_relationships = self.process_relationships(
+                relationships
+            )
+
+            # Generate output
+            output = self.generate_output(grouped_relationships, graph)
+
+            return "\n".join(output)
+
+        except Exception as e:
+            logger.error(f"Error printing relationships: {str(e)}")
+            raise R2RException(
+                status_code=500,
+                message=f"An error occurred while fetching relationships: {str(e)}",
+            )
+
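+    # Build an adjacency list plus a {subject: {relation: [objects]}} grouping
+    # from the raw triples.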
+    def process_relationships(
+        self, relationships: List[Tuple[str, str, str]]
+    ) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[str]]]]:
+        graph = defaultdict(list)
+        grouped = defaultdict(lambda: defaultdict(list))
+        for subject, relation, obj in relationships:
+            graph[subject].append(obj)
+            grouped[subject][relation].append(obj)
+            if obj not in graph:
+                graph[obj] = []
+        return dict(graph), dict(grouped)
+
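+    # Render the grouped relationships and summary graph statistics as
+    # printable lines.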
+    def generate_output(
+        self,
+        grouped_relationships: Dict[str, Dict[str, List[str]]],
+        graph: Dict[str, List[str]],
+    ) -> List[str]:
+        output = []
+
+        # Print grouped relationships
+        for subject, relations in grouped_relationships.items():
+            output.append(f"\n== {subject} ==")
+            for relation, objects in relations.items():
+                output.append(f"  {relation}:")
+                for obj in objects:
+                    output.append(f"    - {obj}")
+
+        # Print basic graph statistics
+        output.append("\n== Graph Statistics ==")
+        output.append(f"Number of nodes: {len(graph)}")
+        output.append(
+            f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}"
+        )
+        output.append(
+            f"Number of connected components: {self.count_connected_components(graph)}"
+        )
+
+        # Find central nodes
+        central_nodes = self.get_central_nodes(graph)
+        output.append("\n== Most Central Nodes ==")
+        for node, centrality in central_nodes:
+            output.append(f"  {node}: {centrality:.4f}")
+
+        return output
+
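+    # Count components with a depth-first search started from each
+    # not-yet-visited node in the adjacency list.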
+    def count_connected_components(self, graph: Dict[str, List[str]]) -> int:
+        visited = set()
+        components = 0
+
+        def dfs(node):
+            visited.add(node)
+            for neighbor in graph[node]:
+                if neighbor not in visited:
+                    dfs(neighbor)
+
+        for node in graph:
+            if node not in visited:
+                dfs(node)
+                components += 1
+
+        return components
+
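+    # Degree centrality: degree / (n - 1), returning the five highest-scoring
+    # nodes; single-node graphs yield no result.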
+    def get_central_nodes(
+        self, graph: Dict[str, List[str]]
+    ) -> List[Tuple[str, float]]:
+        degree = {node: len(neighbors) for node, neighbors in graph.items()}
+        total_nodes = len(graph)
+        if total_nodes <= 1:
+            return []
+        centrality = {
+            node: deg / (total_nodes - 1) for node, deg in degree.items()
+        }
+        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
+
+    @telemetry_event("AppSettings")
+    async def app_settings(
+        self,
+        *args,
+        **kwargs,
+    ):
+        prompts = self.providers.prompt.get_all_prompts()
+        return {
+            "config": self.config.to_json(),
+            "prompts": {
+                name: prompt.dict() for name, prompt in prompts.items()
+            },
+        }