about summary refs log tree commit diff
path: root/R2R/r2r/main/services/management_service.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/main/services/management_service.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are here HEAD master
Diffstat (limited to 'R2R/r2r/main/services/management_service.py')
-rwxr-xr-xR2R/r2r/main/services/management_service.py385
1 files changed, 385 insertions, 0 deletions
diff --git a/R2R/r2r/main/services/management_service.py b/R2R/r2r/main/services/management_service.py
new file mode 100755
index 00000000..00f1f56e
--- /dev/null
+++ b/R2R/r2r/main/services/management_service.py
@@ -0,0 +1,385 @@
+import logging
+import uuid
+from collections import defaultdict
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from r2r.base import (
+    AnalysisTypes,
+    FilterCriteria,
+    KVLoggingSingleton,
+    LogProcessor,
+    R2RException,
+    RunManager,
+)
+from r2r.telemetry.telemetry_decorator import telemetry_event
+
+from ..abstractions import R2RPipelines, R2RProviders
+from ..assembly.config import R2RConfig
+from .base import Service
+
+logger = logging.getLogger(__name__)
+
+
+class ManagementService(Service):
+    def __init__(
+        self,
+        config: R2RConfig,
+        providers: R2RProviders,
+        pipelines: R2RPipelines,
+        run_manager: RunManager,
+        logging_connection: KVLoggingSingleton,
+    ):
+        super().__init__(
+            config, providers, pipelines, run_manager, logging_connection
+        )
+
+    @telemetry_event("UpdatePrompt")
+    async def update_prompt(
+        self,
+        name: str,
+        template: Optional[str] = None,
+        input_types: Optional[dict[str, str]] = {},
+        *args,
+        **kwargs,
+    ):
+        self.providers.prompt.update_prompt(name, template, input_types)
+        return f"Prompt '{name}' added successfully."
+
+    @telemetry_event("Logs")
+    async def alogs(
+        self,
+        log_type_filter: Optional[str] = None,
+        max_runs_requested: int = 100,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        if self.logging_connection is None:
+            raise R2RException(
+                status_code=404, message="Logging provider not found."
+            )
+        if (
+            self.config.app.get("max_logs_per_request", 100)
+            > max_runs_requested
+        ):
+            raise R2RException(
+                status_code=400,
+                message="Max runs requested exceeds the limit.",
+            )
+
+        run_info = await self.logging_connection.get_run_info(
+            limit=max_runs_requested,
+            log_type_filter=log_type_filter,
+        )
+        run_ids = [run.run_id for run in run_info]
+        if len(run_ids) == 0:
+            return []
+        logs = await self.logging_connection.get_logs(run_ids)
+        # Aggregate logs by run_id and include run_type
+        aggregated_logs = []
+
+        for run in run_info:
+            run_logs = [log for log in logs if log["log_id"] == run.run_id]
+            entries = [
+                {"key": log["key"], "value": log["value"]} for log in run_logs
+            ][
+                ::-1
+            ]  # Reverse order so that earliest logged values appear first.
+            aggregated_logs.append(
+                {
+                    "run_id": run.run_id,
+                    "run_type": run.log_type,
+                    "entries": entries,
+                }
+            )
+
+        return aggregated_logs
+
+    @telemetry_event("Analytics")
+    async def aanalytics(
+        self,
+        filter_criteria: FilterCriteria,
+        analysis_types: AnalysisTypes,
+        *args,
+        **kwargs,
+    ):
+        run_info = await self.logging_connection.get_run_info(limit=100)
+        run_ids = [info.run_id for info in run_info]
+
+        if not run_ids:
+            return {
+                "analytics_data": "No logs found.",
+                "filtered_logs": {},
+            }
+        logs = await self.logging_connection.get_logs(run_ids=run_ids)
+
+        filters = {}
+        if filter_criteria.filters:
+            for key, value in filter_criteria.filters.items():
+                filters[key] = lambda log, value=value: (
+                    any(
+                        entry.get("key") == value
+                        for entry in log.get("entries", [])
+                    )
+                    if "entries" in log
+                    else log.get("key") == value
+                )
+
+        log_processor = LogProcessor(filters)
+        for log in logs:
+            if "entries" in log and isinstance(log["entries"], list):
+                log_processor.process_log(log)
+            elif "key" in log:
+                log_processor.process_log(log)
+            else:
+                logger.warning(
+                    f"Skipping log due to missing or malformed 'entries': {log}"
+                )
+
+        filtered_logs = dict(log_processor.populations.items())
+        results = {"filtered_logs": filtered_logs}
+
+        if analysis_types and analysis_types.analysis_types:
+            for (
+                filter_key,
+                analysis_config,
+            ) in analysis_types.analysis_types.items():
+                if filter_key in filtered_logs:
+                    analysis_type = analysis_config[0]
+                    if analysis_type == "bar_chart":
+                        extract_key = analysis_config[1]
+                        results[filter_key] = (
+                            AnalysisTypes.generate_bar_chart_data(
+                                filtered_logs[filter_key], extract_key
+                            )
+                        )
+                    elif analysis_type == "basic_statistics":
+                        extract_key = analysis_config[1]
+                        results[filter_key] = (
+                            AnalysisTypes.calculate_basic_statistics(
+                                filtered_logs[filter_key], extract_key
+                            )
+                        )
+                    elif analysis_type == "percentile":
+                        extract_key = analysis_config[1]
+                        percentile = int(analysis_config[2])
+                        results[filter_key] = (
+                            AnalysisTypes.calculate_percentile(
+                                filtered_logs[filter_key],
+                                extract_key,
+                                percentile,
+                            )
+                        )
+                    else:
+                        logger.warning(
+                            f"Unknown analysis type for filter key '{filter_key}': {analysis_type}"
+                        )
+
+        return results
+
+    @telemetry_event("AppSettings")
+    async def aapp_settings(self, *args: Any, **kwargs: Any):
+        prompts = self.providers.prompt.get_all_prompts()
+        return {
+            "config": self.config.to_json(),
+            "prompts": {
+                name: prompt.dict() for name, prompt in prompts.items()
+            },
+        }
+
+    @telemetry_event("UsersOverview")
+    async def ausers_overview(
+        self,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_users_overview(
+            [str(ele) for ele in user_ids] if user_ids else None
+        )
+
+    @telemetry_event("Delete")
+    async def delete(
+        self,
+        keys: list[str],
+        values: list[Union[bool, int, str]],
+        *args,
+        **kwargs,
+    ):
+        metadata = ", ".join(
+            f"{key}={value}" for key, value in zip(keys, values)
+        )
+        values = [str(value) for value in values]
+        logger.info(f"Deleting entries with metadata: {metadata}")
+        ids = self.providers.vector_db.delete_by_metadata(keys, values)
+        if not ids:
+            raise R2RException(
+                status_code=404, message="No entries found for deletion."
+            )
+        for id in ids:
+            self.providers.vector_db.delete_from_documents_overview(id)
+        return f"Documents {ids} deleted successfully."
+
+    @telemetry_event("DocumentsOverview")
+    async def adocuments_overview(
+        self,
+        document_ids: Optional[list[uuid.UUID]] = None,
+        user_ids: Optional[list[uuid.UUID]] = None,
+        *args: Any,
+        **kwargs: Any,
+    ):
+        return self.providers.vector_db.get_documents_overview(
+            filter_document_ids=(
+                [str(ele) for ele in document_ids] if document_ids else None
+            ),
+            filter_user_ids=(
+                [str(ele) for ele in user_ids] if user_ids else None
+            ),
+        )
+
+    @telemetry_event("DocumentChunks")
+    async def document_chunks(
+        self,
+        document_id: uuid.UUID,
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_document_chunks(str(document_id))
+
+    @telemetry_event("UsersOverview")
+    async def users_overview(
+        self,
+        user_ids: Optional[list[uuid.UUID]],
+        *args,
+        **kwargs,
+    ):
+        return self.providers.vector_db.get_users_overview(
+            [str(ele) for ele in user_ids]
+        )
+
+    @telemetry_event("InspectKnowledgeGraph")
+    async def inspect_knowledge_graph(
+        self, limit=10000, *args: Any, **kwargs: Any
+    ):
+        if self.providers.kg is None:
+            raise R2RException(
+                status_code=404, message="Knowledge Graph provider not found."
+            )
+
+        rel_query = f"""
+        MATCH (n1)-[r]->(n2)
+        RETURN n1.id AS subject, type(r) AS relation, n2.id AS object
+        LIMIT {limit}
+        """
+
+        try:
+            with self.providers.kg.client.session(
+                database=self.providers.kg._database
+            ) as session:
+                results = session.run(rel_query)
+                relationships = [
+                    (record["subject"], record["relation"], record["object"])
+                    for record in results
+                ]
+
+            # Create graph representation and group relationships
+            graph, grouped_relationships = self.process_relationships(
+                relationships
+            )
+
+            # Generate output
+            output = self.generate_output(grouped_relationships, graph)
+
+            return "\n".join(output)
+
+        except Exception as e:
+            logger.error(f"Error printing relationships: {str(e)}")
+            raise R2RException(
+                status_code=500,
+                message=f"An error occurred while fetching relationships: {str(e)}",
+            )
+
+    def process_relationships(
+        self, relationships: List[Tuple[str, str, str]]
+    ) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[str]]]]:
+        graph = defaultdict(list)
+        grouped = defaultdict(lambda: defaultdict(list))
+        for subject, relation, obj in relationships:
+            graph[subject].append(obj)
+            grouped[subject][relation].append(obj)
+            if obj not in graph:
+                graph[obj] = []
+        return dict(graph), dict(grouped)
+
+    def generate_output(
+        self,
+        grouped_relationships: Dict[str, Dict[str, List[str]]],
+        graph: Dict[str, List[str]],
+    ) -> List[str]:
+        output = []
+
+        # Print grouped relationships
+        for subject, relations in grouped_relationships.items():
+            output.append(f"\n== {subject} ==")
+            for relation, objects in relations.items():
+                output.append(f"  {relation}:")
+                for obj in objects:
+                    output.append(f"    - {obj}")
+
+        # Print basic graph statistics
+        output.append("\n== Graph Statistics ==")
+        output.append(f"Number of nodes: {len(graph)}")
+        output.append(
+            f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}"
+        )
+        output.append(
+            f"Number of connected components: {self.count_connected_components(graph)}"
+        )
+
+        # Find central nodes
+        central_nodes = self.get_central_nodes(graph)
+        output.append("\n== Most Central Nodes ==")
+        for node, centrality in central_nodes:
+            output.append(f"  {node}: {centrality:.4f}")
+
+        return output
+
+    def count_connected_components(self, graph: Dict[str, List[str]]) -> int:
+        visited = set()
+        components = 0
+
+        def dfs(node):
+            visited.add(node)
+            for neighbor in graph[node]:
+                if neighbor not in visited:
+                    dfs(neighbor)
+
+        for node in graph:
+            if node not in visited:
+                dfs(node)
+                components += 1
+
+        return components
+
+    def get_central_nodes(
+        self, graph: Dict[str, List[str]]
+    ) -> List[Tuple[str, float]]:
+        degree = {node: len(neighbors) for node, neighbors in graph.items()}
+        total_nodes = len(graph)
+        centrality = {
+            node: deg / (total_nodes - 1) for node, deg in degree.items()
+        }
+        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]
+
+    @telemetry_event("AppSettings")
+    async def app_settings(
+        self,
+        *args,
+        **kwargs,
+    ):
+        prompts = self.providers.prompt.get_all_prompts()
+        return {
+            "config": self.config.to_json(),
+            "prompts": {
+                name: prompt.dict() for name, prompt in prompts.items()
+            },
+        }