aboutsummaryrefslogtreecommitdiff
import logging
import uuid
from collections import defaultdict
from typing import Any, Dict, List, Optional, Tuple, Union

from r2r.base import (
    AnalysisTypes,
    FilterCriteria,
    KVLoggingSingleton,
    LogProcessor,
    R2RException,
    RunManager,
)
from r2r.telemetry.telemetry_decorator import telemetry_event

from ..abstractions import R2RPipelines, R2RProviders
from ..assembly.config import R2RConfig
from .base import Service

logger = logging.getLogger(__name__)


class ManagementService(Service):
    def __init__(
        self,
        config: R2RConfig,
        providers: R2RProviders,
        pipelines: R2RPipelines,
        run_manager: RunManager,
        logging_connection: KVLoggingSingleton,
    ):
        super().__init__(
            config, providers, pipelines, run_manager, logging_connection
        )

    @telemetry_event("UpdatePrompt")
    async def update_prompt(
        self,
        name: str,
        template: Optional[str] = None,
        input_types: Optional[dict[str, str]] = {},
        *args,
        **kwargs,
    ):
        self.providers.prompt.update_prompt(name, template, input_types)
        return f"Prompt '{name}' added successfully."

    @telemetry_event("Logs")
    async def alogs(
        self,
        log_type_filter: Optional[str] = None,
        max_runs_requested: int = 100,
        *args: Any,
        **kwargs: Any,
    ):
        if self.logging_connection is None:
            raise R2RException(
                status_code=404, message="Logging provider not found."
            )
        if (
            self.config.app.get("max_logs_per_request", 100)
            > max_runs_requested
        ):
            raise R2RException(
                status_code=400,
                message="Max runs requested exceeds the limit.",
            )

        run_info = await self.logging_connection.get_run_info(
            limit=max_runs_requested,
            log_type_filter=log_type_filter,
        )
        run_ids = [run.run_id for run in run_info]
        if len(run_ids) == 0:
            return []
        logs = await self.logging_connection.get_logs(run_ids)
        # Aggregate logs by run_id and include run_type
        aggregated_logs = []

        for run in run_info:
            run_logs = [log for log in logs if log["log_id"] == run.run_id]
            entries = [
                {"key": log["key"], "value": log["value"]} for log in run_logs
            ][
                ::-1
            ]  # Reverse order so that earliest logged values appear first.
            aggregated_logs.append(
                {
                    "run_id": run.run_id,
                    "run_type": run.log_type,
                    "entries": entries,
                }
            )

        return aggregated_logs

    @telemetry_event("Analytics")
    async def aanalytics(
        self,
        filter_criteria: FilterCriteria,
        analysis_types: AnalysisTypes,
        *args,
        **kwargs,
    ):
        run_info = await self.logging_connection.get_run_info(limit=100)
        run_ids = [info.run_id for info in run_info]

        if not run_ids:
            return {
                "analytics_data": "No logs found.",
                "filtered_logs": {},
            }
        logs = await self.logging_connection.get_logs(run_ids=run_ids)

        filters = {}
        if filter_criteria.filters:
            for key, value in filter_criteria.filters.items():
                filters[key] = lambda log, value=value: (
                    any(
                        entry.get("key") == value
                        for entry in log.get("entries", [])
                    )
                    if "entries" in log
                    else log.get("key") == value
                )

        log_processor = LogProcessor(filters)
        for log in logs:
            if "entries" in log and isinstance(log["entries"], list):
                log_processor.process_log(log)
            elif "key" in log:
                log_processor.process_log(log)
            else:
                logger.warning(
                    f"Skipping log due to missing or malformed 'entries': {log}"
                )

        filtered_logs = dict(log_processor.populations.items())
        results = {"filtered_logs": filtered_logs}

        if analysis_types and analysis_types.analysis_types:
            for (
                filter_key,
                analysis_config,
            ) in analysis_types.analysis_types.items():
                if filter_key in filtered_logs:
                    analysis_type = analysis_config[0]
                    if analysis_type == "bar_chart":
                        extract_key = analysis_config[1]
                        results[filter_key] = (
                            AnalysisTypes.generate_bar_chart_data(
                                filtered_logs[filter_key], extract_key
                            )
                        )
                    elif analysis_type == "basic_statistics":
                        extract_key = analysis_config[1]
                        results[filter_key] = (
                            AnalysisTypes.calculate_basic_statistics(
                                filtered_logs[filter_key], extract_key
                            )
                        )
                    elif analysis_type == "percentile":
                        extract_key = analysis_config[1]
                        percentile = int(analysis_config[2])
                        results[filter_key] = (
                            AnalysisTypes.calculate_percentile(
                                filtered_logs[filter_key],
                                extract_key,
                                percentile,
                            )
                        )
                    else:
                        logger.warning(
                            f"Unknown analysis type for filter key '{filter_key}': {analysis_type}"
                        )

        return results

    @telemetry_event("AppSettings")
    async def aapp_settings(self, *args: Any, **kwargs: Any):
        prompts = self.providers.prompt.get_all_prompts()
        return {
            "config": self.config.to_json(),
            "prompts": {
                name: prompt.dict() for name, prompt in prompts.items()
            },
        }

    @telemetry_event("UsersOverview")
    async def ausers_overview(
        self,
        user_ids: Optional[list[uuid.UUID]] = None,
        *args,
        **kwargs,
    ):
        return self.providers.vector_db.get_users_overview(
            [str(ele) for ele in user_ids] if user_ids else None
        )

    @telemetry_event("Delete")
    async def delete(
        self,
        keys: list[str],
        values: list[Union[bool, int, str]],
        *args,
        **kwargs,
    ):
        metadata = ", ".join(
            f"{key}={value}" for key, value in zip(keys, values)
        )
        values = [str(value) for value in values]
        logger.info(f"Deleting entries with metadata: {metadata}")
        ids = self.providers.vector_db.delete_by_metadata(keys, values)
        if not ids:
            raise R2RException(
                status_code=404, message="No entries found for deletion."
            )
        for id in ids:
            self.providers.vector_db.delete_from_documents_overview(id)
        return f"Documents {ids} deleted successfully."

    @telemetry_event("DocumentsOverview")
    async def adocuments_overview(
        self,
        document_ids: Optional[list[uuid.UUID]] = None,
        user_ids: Optional[list[uuid.UUID]] = None,
        *args: Any,
        **kwargs: Any,
    ):
        return self.providers.vector_db.get_documents_overview(
            filter_document_ids=(
                [str(ele) for ele in document_ids] if document_ids else None
            ),
            filter_user_ids=(
                [str(ele) for ele in user_ids] if user_ids else None
            ),
        )

    @telemetry_event("DocumentChunks")
    async def document_chunks(
        self,
        document_id: uuid.UUID,
        *args,
        **kwargs,
    ):
        return self.providers.vector_db.get_document_chunks(str(document_id))

    @telemetry_event("UsersOverview")
    async def users_overview(
        self,
        user_ids: Optional[list[uuid.UUID]],
        *args,
        **kwargs,
    ):
        return self.providers.vector_db.get_users_overview(
            [str(ele) for ele in user_ids]
        )

    @telemetry_event("InspectKnowledgeGraph")
    async def inspect_knowledge_graph(
        self, limit=10000, *args: Any, **kwargs: Any
    ):
        if self.providers.kg is None:
            raise R2RException(
                status_code=404, message="Knowledge Graph provider not found."
            )

        rel_query = f"""
        MATCH (n1)-[r]->(n2)
        RETURN n1.id AS subject, type(r) AS relation, n2.id AS object
        LIMIT {limit}
        """

        try:
            with self.providers.kg.client.session(
                database=self.providers.kg._database
            ) as session:
                results = session.run(rel_query)
                relationships = [
                    (record["subject"], record["relation"], record["object"])
                    for record in results
                ]

            # Create graph representation and group relationships
            graph, grouped_relationships = self.process_relationships(
                relationships
            )

            # Generate output
            output = self.generate_output(grouped_relationships, graph)

            return "\n".join(output)

        except Exception as e:
            logger.error(f"Error printing relationships: {str(e)}")
            raise R2RException(
                status_code=500,
                message=f"An error occurred while fetching relationships: {str(e)}",
            )

    def process_relationships(
        self, relationships: List[Tuple[str, str, str]]
    ) -> Tuple[Dict[str, List[str]], Dict[str, Dict[str, List[str]]]]:
        graph = defaultdict(list)
        grouped = defaultdict(lambda: defaultdict(list))
        for subject, relation, obj in relationships:
            graph[subject].append(obj)
            grouped[subject][relation].append(obj)
            if obj not in graph:
                graph[obj] = []
        return dict(graph), dict(grouped)

    def generate_output(
        self,
        grouped_relationships: Dict[str, Dict[str, List[str]]],
        graph: Dict[str, List[str]],
    ) -> List[str]:
        output = []

        # Print grouped relationships
        for subject, relations in grouped_relationships.items():
            output.append(f"\n== {subject} ==")
            for relation, objects in relations.items():
                output.append(f"  {relation}:")
                for obj in objects:
                    output.append(f"    - {obj}")

        # Print basic graph statistics
        output.append("\n== Graph Statistics ==")
        output.append(f"Number of nodes: {len(graph)}")
        output.append(
            f"Number of edges: {sum(len(neighbors) for neighbors in graph.values())}"
        )
        output.append(
            f"Number of connected components: {self.count_connected_components(graph)}"
        )

        # Find central nodes
        central_nodes = self.get_central_nodes(graph)
        output.append("\n== Most Central Nodes ==")
        for node, centrality in central_nodes:
            output.append(f"  {node}: {centrality:.4f}")

        return output

    def count_connected_components(self, graph: Dict[str, List[str]]) -> int:
        visited = set()
        components = 0

        def dfs(node):
            visited.add(node)
            for neighbor in graph[node]:
                if neighbor not in visited:
                    dfs(neighbor)

        for node in graph:
            if node not in visited:
                dfs(node)
                components += 1

        return components

    def get_central_nodes(
        self, graph: Dict[str, List[str]]
    ) -> List[Tuple[str, float]]:
        degree = {node: len(neighbors) for node, neighbors in graph.items()}
        total_nodes = len(graph)
        centrality = {
            node: deg / (total_nodes - 1) for node, deg in degree.items()
        }
        return sorted(centrality.items(), key=lambda x: x[1], reverse=True)[:5]

    @telemetry_event("AppSettings")
    async def app_settings(
        self,
        *args,
        **kwargs,
    ):
        prompts = self.providers.prompt.get_all_prompts()
        return {
            "config": self.config.to_json(),
            "prompts": {
                name: prompt.dict() for name, prompt in prompts.items()
            },
        }