import logging
import uuid
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union
from r2r.base import (
AnalysisTypes,
FilterCriteria,
KVLoggingSingleton,
LogProcessor,
R2RException,
RunManager,
)
from r2r.telemetry.telemetry_decorator import telemetry_event
from ..abstractions import R2RPipelines, R2RProviders
from ..assembly.config import R2RConfig
from .base import Service
# Module-level logger, named after this module's fully qualified name.
logger = logging.getLogger(__name__)
class ManagementService(Service):
    """Administrative operations: prompts, logs/analytics, app settings,
    user and document overviews, deletion, and knowledge-graph inspection."""

    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipelines: R2RPipelines,
        run_manager: RunManager,
        logging_connection: KVLoggingSingleton,
    ):
        super().__init__(
            config, providers, pipelines, run_manager, logging_connection
        )

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _ids_to_str(
        ids: Optional[list[uuid.UUID]],
    ) -> Optional[list[str]]:
        """Stringify a list of UUIDs; ``None`` or an empty list maps to ``None``."""
        return [str(ele) for ele in ids] if ids else None

    def _require_logging_connection(self) -> KVLoggingSingleton:
        """Return the logging connection.

        Raises:
            R2RException: 404 if no logging connection is configured.
        """
        if self.logging_connection is None:
            raise R2RException(
                status_code=404, message="Logging provider not found."
            )
        return self.logging_connection

    def _app_settings_payload(self) -> dict:
        """Build the ``{"config": ..., "prompts": ...}`` settings payload."""
        prompts = self.providers.prompt.get_all_prompts()
        return {
            "config": self.config.to_json(),
            "prompts": {
                name: prompt.dict() for name, prompt in prompts.items()
            },
        }

    # ------------------------------------------------------------------ #
    # Prompts
    # ------------------------------------------------------------------ #

    @telemetry_event("UpdatePrompt")
    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = {},
        *args,
        **kwargs,
    ):
        """Update prompt ``name`` via the prompt provider.

        ``input_types`` keeps its historical ``{}`` default for backward
        compatibility. The dict is copied before being handed to the
        provider, so the provider can never alias or mutate the shared
        default object (or the caller's dict). An explicit ``None`` is
        passed through unchanged.
        """
        if input_types is not None:
            input_types = dict(input_types)
        self.providers.prompt.update_prompt(name, template, input_types)
        return f"Prompt '{name}' added successfully."

    # ------------------------------------------------------------------ #
    # Logs and analytics
    # ------------------------------------------------------------------ #

    @telemetry_event("Logs")
    async def alogs(
        self,
        log_type_filter: Optional[str] = None,
        max_runs_requested: int = 100,
        *args: Any,
        **kwargs: Any,
    ):
        """Return logged entries grouped per run.

        Returns:
            A list of ``{"run_id", "run_type", "entries"}`` dicts, one per
            run returned by ``get_run_info``. ``entries`` holds
            ``{"key", "value"}`` pairs in reverse of the order
            ``get_logs`` returned them. The list is empty if there are no
            runs.

        Raises:
            R2RException: 404 if no logging connection is configured; 400 if
                ``max_runs_requested`` exceeds ``max_logs_per_request``
                (default 100) from the app config.
        """
        logging_connection = self._require_logging_connection()

        max_logs = self.config.app.get("max_logs_per_request", 100)
        # Reject only requests ABOVE the configured cap. The previous
        # comparison was inverted and rejected every request below the cap.
        if max_runs_requested > max_logs:
            raise R2RException(
                status_code=400,
                message="Max runs requested exceeds the limit.",
            )

        run_info = await logging_connection.get_run_info(
            limit=max_runs_requested,
            log_type_filter=log_type_filter,
        )
        run_ids = [run.run_id for run in run_info]
        if not run_ids:
            return []

        logs = await logging_connection.get_logs(run_ids)

        # Group logs by run once, instead of rescanning every log per run.
        logs_by_run = defaultdict(list)
        for log in logs:
            logs_by_run[log["log_id"]].append(log)

        aggregated_logs = []
        for run in run_info:
            run_logs = logs_by_run.get(run.run_id, [])
            # Reverse order so that earliest logged values appear first.
            entries = [
                {"key": log["key"], "value": log["value"]}
                for log in reversed(run_logs)
            ]
            aggregated_logs.append(
                {
                    "run_id": run.run_id,
                    "run_type": run.log_type,
                    "entries": entries,
                }
            )
        return aggregated_logs

    @telemetry_event("Analytics")
    async def aanalytics(
        self,
        filter_criteria: FilterCriteria,
        analysis_types: AnalysisTypes,
        *args,
        **kwargs,
    ):
        """Filter logs from up to 100 runs and optionally run analyses on them.

        Returns:
            ``{"filtered_logs": {...}}``, plus one entry per filter key that
            has a recognised analysis (``bar_chart``, ``basic_statistics`` or
            ``percentile``). If there are no runs, returns
            ``{"analytics_data": "No logs found.", "filtered_logs": {}}``.

        Raises:
            R2RException: 404 if no logging connection is configured
                (consistent with ``alogs``).
        """
        logging_connection = self._require_logging_connection()

        run_info = await logging_connection.get_run_info(limit=100)
        run_ids = [info.run_id for info in run_info]
        if not run_ids:
            return {
                "analytics_data": "No logs found.",
                "filtered_logs": {},
            }
        logs = await logging_connection.get_logs(run_ids=run_ids)

        filters = {}
        if filter_criteria.filters:
            for key, value in filter_criteria.filters.items():
                # ``value=value`` binds the current value; a plain closure
                # would late-bind to the last value in the loop.
                filters[key] = lambda log, value=value: (
                    any(
                        entry.get("key") == value
                        for entry in log.get("entries", [])
                    )
                    if "entries" in log
                    else log.get("key") == value
                )

        log_processor = LogProcessor(filters)
        for log in logs:
            has_entries = "entries" in log and isinstance(
                log["entries"], list
            )
            if has_entries or "key" in log:
                log_processor.process_log(log)
            else:
                logger.warning(
                    f"Skipping log due to missing or malformed 'entries': {log}"
                )

        filtered_logs = dict(log_processor.populations.items())
        results = {"filtered_logs": filtered_logs}

        if analysis_types and analysis_types.analysis_types:
            for (
                filter_key,
                analysis_config,
            ) in analysis_types.analysis_types.items():
                if filter_key not in filtered_logs:
                    continue
                population = filtered_logs[filter_key]
                analysis_type = analysis_config[0]
                if analysis_type == "bar_chart":
                    results[filter_key] = (
                        AnalysisTypes.generate_bar_chart_data(
                            population, analysis_config[1]
                        )
                    )
                elif analysis_type == "basic_statistics":
                    results[filter_key] = (
                        AnalysisTypes.calculate_basic_statistics(
                            population, analysis_config[1]
                        )
                    )
                elif analysis_type == "percentile":
                    extract_key = analysis_config[1]
                    percentile = int(analysis_config[2])
                    results[filter_key] = AnalysisTypes.calculate_percentile(
                        population,
                        extract_key,
                        percentile,
                    )
                else:
                    logger.warning(
                        f"Unknown analysis type for filter key '{filter_key}': {analysis_type}"
                    )
        return results

    # ------------------------------------------------------------------ #
    # Settings and overviews
    # ------------------------------------------------------------------ #

    @telemetry_event("AppSettings")
    async def aapp_settings(self, *args: Any, **kwargs: Any):
        """Return the app config (as JSON) and all prompts (as dicts)."""
        return self._app_settings_payload()

    @telemetry_event("UsersOverview")
    async def ausers_overview(
        self,
        user_ids: Optional[list[uuid.UUID]] = None,
        *args,
        **kwargs,
    ):
        """Return the users overview; ``None`` or ``[]`` means no user filter."""
        return self.providers.vector_db.get_users_overview(
            self._ids_to_str(user_ids)
        )

    @telemetry_event("Delete")
    async def delete(
        self,
        keys: list[str],
        values: list[Union[bool, int, str]],
        *args,
        **kwargs,
    ):
        """Delete vector entries matching ``keys[i] == values[i]`` metadata.

        Also removes each returned id from the documents overview.

        Raises:
            R2RException: 400 if ``keys`` and ``values`` differ in length.
                (``zip`` would otherwise silently drop the extras before a
                destructive delete.) 404 if nothing matched.
        """
        if len(keys) != len(values):
            raise R2RException(
                status_code=400,
                message="The number of keys and values must match.",
            )
        metadata = ", ".join(
            f"{key}={value}" for key, value in zip(keys, values)
        )
        str_values = [str(value) for value in values]
        logger.info(f"Deleting entries with metadata: {metadata}")
        ids = self.providers.vector_db.delete_by_metadata(keys, str_values)
        if not ids:
            raise R2RException(
                status_code=404, message="No entries found for deletion."
            )
        for doc_id in ids:
            self.providers.vector_db.delete_from_documents_overview(doc_id)
        return f"Documents {ids} deleted successfully."

    @telemetry_event("DocumentsOverview")
    async def adocuments_overview(
        self,
        document_ids: Optional[list[uuid.UUID]] = None,
        user_ids: Optional[list[uuid.UUID]] = None,
        *args: Any,
        **kwargs: Any,
    ):
        """Return the documents overview, optionally filtered by document/user ids.

        ``None`` or an empty list for either filter is passed as ``None``.
        """
        return self.providers.vector_db.get_documents_overview(
            filter_document_ids=self._ids_to_str(document_ids),
            filter_user_ids=self._ids_to_str(user_ids),
        )

    @telemetry_event("DocumentChunks")
    async def document_chunks(
        self,
        document_id: uuid.UUID,
        *args,
        **kwargs,
    ):
        """Return the chunks stored for ``document_id``."""
        return self.providers.vector_db.get_document_chunks(str(document_id))

    @telemetry_event("UsersOverview")
    async def users_overview(
        self,
        user_ids: Optional[list[uuid.UUID]] = None,
        *args,
        **kwargs,
    ):
        """Return the users overview for ``user_ids``.

        ``None`` is now forwarded as ``None`` instead of raising
        ``TypeError``. A list (including an empty one) is stringified and
        forwarded as before.
        """
        return self.providers.vector_db.get_users_overview(
            [str(ele) for ele in user_ids] if user_ids is not None else None
        )

    # ------------------------------------------------------------------ #
    # Knowledge graph
    # ------------------------------------------------------------------ #

    @telemetry_event("InspectKnowledgeGraph")
    async def inspect_knowledge_graph(
        self, limit=10000, *args: Any, **kwargs: Any
    ):
        """Return a text report of up to ``limit`` KG relationships plus stats.

        Raises:
            R2RException: 404 if no KG provider is configured; 400 if
                ``limit`` is not a non-negative integer; 500 if querying or
                formatting fails.
        """
        if self.providers.kg is None:
            raise R2RException(
                status_code=404, message="Knowledge Graph provider not found."
            )

        # ``limit`` is interpolated into the Cypher text, so coerce it to a
        # plain non-negative int first to rule out query injection.
        try:
            limit = int(limit)
        except (TypeError, ValueError) as e:
            raise R2RException(
                status_code=400, message=f"Invalid limit: {limit!r}"
            ) from e
        if limit < 0:
            raise R2RException(
                status_code=400, message="Limit must be non-negative."
            )

        rel_query = f"""
        MATCH (n1)-[r]->(n2)
        RETURN n1.id AS subject, type(r) AS relation, n2.id AS object
        LIMIT {limit}
        """

        try:
            with self.providers.kg.client.session(
                database=self.providers.kg._database
            ) as session:
                results = session.run(rel_query)
                # Consume the result while the session is still open.
                relationships = [
                    (record["subject"], record["relation"], record["object"])
                    for record in results
                ]

            # Create graph representation and group relationships.
            graph, grouped_relationships = self.process_relationships(
                relationships
            )
            output = self.generate_output(grouped_relationships, graph)
            return "\n".join(output)
        except Exception as e:
            logger.exception(f"Error printing relationships: {str(e)}")
            raise R2RException(
                status_code=500,
                message=f"An error occurred while fetching relationships: {str(e)}",
            ) from e

    def process_relationships(
        self, relationships: List[Tuple[str, str, str]]
    ) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[str]]]]:
        """Build a directed adjacency map and a subject→relation→objects grouping.

        Every object also appears as a key in the adjacency map, with an
        empty list if it has no outgoing edges of its own.
        """
        graph = defaultdict(list)
        grouped = defaultdict(lambda: defaultdict(list))
        for subject, relation, obj in relationships:
            graph[subject].append(obj)
            grouped[subject][relation].append(obj)
            if obj not in graph:
                graph[obj] = []
        return dict(graph), dict(grouped)

    def generate_output(
        self,
        grouped_relationships: Dict[str, Dict[str, List[str]]],
        graph: Dict[str, List[str]],
    ) -> List[str]:
        """Render grouped relationships, graph statistics and top central nodes."""
        output = []

        # Grouped relationships.
        for subject, relations in grouped_relationships.items():
            output.append(f"\n== {subject} ==")
            for relation, objects in relations.items():
                output.append(f" {relation}:")
                for obj in objects:
                    output.append(f" - {obj}")

        # Basic graph statistics.
        output.append("\n== Graph Statistics ==")
        output.append(f"Number of nodes: {len(graph)}")
        output.append(
            f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}"
        )
        output.append(
            f"Number of connected components: {self.count_connected_components(graph)}"
        )

        # Most central nodes.
        central_nodes = self.get_central_nodes(graph)
        output.append("\n== Most Central Nodes ==")
        for node, centrality in central_nodes:
            output.append(f" {node}: {centrality:.4f}")
        return output

    def count_connected_components(self, graph: Dict[str, List[str]]) -> int:
        """Count weakly connected components of the directed adjacency ``graph``.

        Edges are treated as undirected. A DFS that only follows outgoing
        edges gives an order-dependent count (e.g. A->C, B->C would count
        as two components). The traversal is iterative, so long chains
        cannot hit Python's recursion limit. Neighbours missing from the
        keys are treated as nodes rather than raising ``KeyError``.
        """
        adjacency: Dict[str, set] = {node: set() for node in graph}
        for node, neighbors in graph.items():
            for neighbor in neighbors:
                adjacency[node].add(neighbor)
                adjacency.setdefault(neighbor, set()).add(node)

        visited = set()
        components = 0
        for start in adjacency:
            if start in visited:
                continue
            components += 1
            visited.add(start)
            stack = [start]
            while stack:
                current = stack.pop()
                for neighbor in adjacency[current]:
                    if neighbor not in visited:
                        visited.add(neighbor)
                        stack.append(neighbor)
        return components

    def get_central_nodes(
        self, graph: Dict[str, List[str]], top_n: int = 5
    ) -> List[Tuple[str, float]]:
        """Return the ``top_n`` nodes by normalized out-degree, highest first.

        Centrality is ``len(graph[node]) / (n - 1)``, counting outgoing
        neighbours only. For graphs with at most one node the ``n - 1``
        normalization is undefined (previously ``ZeroDivisionError``), so
        every node gets 1.0, following networkx's ``degree_centrality``
        convention.
        """
        degree = {node: len(neighbors) for node, neighbors in graph.items()}
        total_nodes = len(graph)
        if total_nodes <= 1:
            centrality = {node: 1.0 for node in degree}
        else:
            centrality = {
                node: deg / (total_nodes - 1) for node, deg in degree.items()
            }
        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[
            :top_n
        ]

    @telemetry_event("AppSettings")
    async def app_settings(
        self,
        *args,
        **kwargs,
    ):
        """Return the app config (as JSON) and all prompts (as dicts)."""
        return self._app_settings_payload()