Diffstat (limited to 'R2R/r2r/base/logging')
-rwxr-xr-x  R2R/r2r/base/logging/__init__.py         0
-rwxr-xr-x  R2R/r2r/base/logging/kv_logger.py      547
-rwxr-xr-x  R2R/r2r/base/logging/log_processor.py  196
-rwxr-xr-x  R2R/r2r/base/logging/run_manager.py     56
4 files changed, 799 insertions, 0 deletions
diff --git a/R2R/r2r/base/logging/__init__.py b/R2R/r2r/base/logging/__init__.py
new file mode 100755
index 00000000..e69de29b
--- /dev/null
+++ b/R2R/r2r/base/logging/__init__.py
diff --git a/R2R/r2r/base/logging/kv_logger.py b/R2R/r2r/base/logging/kv_logger.py
new file mode 100755
index 00000000..2d444e9f
--- /dev/null
+++ b/R2R/r2r/base/logging/kv_logger.py
@@ -0,0 +1,547 @@
+import json
+import logging
+import os
+import uuid
+from abc import abstractmethod
+from datetime import datetime
+from typing import Optional
+
+import asyncpg
+from pydantic import BaseModel
+
+from ..providers.base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class RunInfo(BaseModel):
+    run_id: uuid.UUID
+    log_type: str
+
+
+class LoggingConfig(ProviderConfig):
+    provider: str = "local"
+    log_table: str = "logs"
+    log_info_table: str = "logs_pipeline_info"
+    logging_path: Optional[str] = None
+
+    def validate(self) -> None:
+        pass
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["local", "postgres", "redis"]
+
+
+class KVLoggingProvider(Provider):
+    @abstractmethod
+    async def close(self):
+        pass
+
+    @abstractmethod
+    async def log(
+        self, log_id: uuid.UUID, key: str, value: str, is_info_log=False
+    ):
+        pass
+
+    @abstractmethod
+    async def get_run_info(
+        self,
+        limit: int = 10,
+        log_type_filter: Optional[str] = None,
+    ) -> list[RunInfo]:
+        pass
+
+    @abstractmethod
+    async def get_logs(
+        self, run_ids: list[uuid.UUID], limit_per_run: int
+    ) -> list:
+        pass
+
+
+class LocalKVLoggingProvider(KVLoggingProvider):
+    def __init__(self, config: LoggingConfig):
+        self.log_table = config.log_table
+        self.log_info_table = config.log_info_table
+        self.logging_path = config.logging_path or os.getenv(
+            "LOCAL_DB_PATH", "local.sqlite"
+        )
+        if not self.logging_path:
+            raise ValueError(
+                "Please set the environment variable LOCAL_DB_PATH."
+            )
+        self.conn = None
+        try:
+            import aiosqlite
+
+            self.aiosqlite = aiosqlite
+        except ImportError:
+            raise ImportError(
+                "Please install aiosqlite to use the LocalKVLoggingProvider."
+            )
+
+    async def init(self):
+        self.conn = await self.aiosqlite.connect(self.logging_path)
+        await self.conn.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {self.log_table} (
+                timestamp DATETIME,
+                log_id TEXT,
+                key TEXT,
+                value TEXT
+            )
+            """
+        )
+        await self.conn.execute(
+            f"""
+            CREATE TABLE IF NOT EXISTS {self.log_info_table} (
+                timestamp DATETIME,
+                log_id TEXT UNIQUE,
+                log_type TEXT
+            )
+        """
+        )
+        await self.conn.commit()
+
+    async def __aenter__(self):
+        if self.conn is None:
+            await self.init()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
+
+    async def close(self):
+        if self.conn:
+            await self.conn.close()
+            self.conn = None
+
+    async def log(
+        self,
+        log_id: uuid.UUID,
+        key: str,
+        value: str,
+        is_info_log=False,
+    ):
+        collection = self.log_info_table if is_info_log else self.log_table
+
+        if is_info_log:
+            if "type" not in key:
+                raise ValueError("Info log keys must contain the text 'type'")
+            await self.conn.execute(
+                f"INSERT INTO {collection} (timestamp, log_id, log_type) VALUES (datetime('now'), ?, ?)",
+                (str(log_id), value),
+            )
+        else:
+            await self.conn.execute(
+                f"INSERT INTO {collection} (timestamp, log_id, key, value) VALUES (datetime('now'), ?, ?, ?)",
+                (str(log_id), key, value),
+            )
+        await self.conn.commit()
+
+    async def get_run_info(
+        self, limit: int = 10, log_type_filter: Optional[str] = None
+    ) -> list[RunInfo]:
+        cursor = await self.conn.cursor()
+        query = f'SELECT log_id, log_type FROM "{self.log_info_table}"'
+        conditions = []
+        params = []
+        if log_type_filter:
+            conditions.append("log_type = ?")
+            params.append(log_type_filter)
+        if conditions:
+            query += " WHERE " + " AND ".join(conditions)
+        query += " ORDER BY timestamp DESC LIMIT ?"
+        params.append(limit)
+        await cursor.execute(query, params)
+        rows = await cursor.fetchall()
+        return [
+            RunInfo(run_id=uuid.UUID(row[0]), log_type=row[1]) for row in rows
+        ]
+
+    async def get_logs(
+        self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+    ) -> list:
+        if not run_ids:
+            raise ValueError("No run ids provided.")
+        cursor = await self.conn.cursor()
+        placeholders = ",".join(["?" for _ in run_ids])
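+        # Window function caps the number of rows returned per run: rows are
+        # ranked newest-first within each log_id, then filtered by rank.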
+        query = f"""
+        SELECT *
+        FROM (
+            SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
+            FROM {self.log_table}
+            WHERE log_id IN ({placeholders})
+        )
+        WHERE rn <= ?
+        ORDER BY timestamp DESC
+        """
+        params = [str(ele) for ele in run_ids] + [limit_per_run]
+        await cursor.execute(query, params)
+        rows = await cursor.fetchall()
+        new_rows = []
+        for row in rows:
+            new_rows.append(
+                (row[0], uuid.UUID(row[1]), row[2], row[3], row[4])
+            )
+        return [
+            {desc[0]: row[i] for i, desc in enumerate(cursor.description)}
+            for row in new_rows
+        ]
+
+
+class PostgresLoggingConfig(LoggingConfig):
+    provider: str = "postgres"
+    log_table: str = "logs"
+    log_info_table: str = "logs_pipeline_info"
+
+    def validate(self) -> None:
+        required_env_vars = [
+            "POSTGRES_DBNAME",
+            "POSTGRES_USER",
+            "POSTGRES_PASSWORD",
+            "POSTGRES_HOST",
+            "POSTGRES_PORT",
+        ]
+        for var in required_env_vars:
+            if not os.getenv(var):
+                raise ValueError(f"Environment variable {var} is not set.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["postgres"]
+
+
+class PostgresKVLoggingProvider(KVLoggingProvider):
+    def __init__(self, config: PostgresLoggingConfig):
+        self.log_table = config.log_table
+        self.log_info_table = config.log_info_table
+        self.config = config
+        self.pool = None
+        if not os.getenv("POSTGRES_DBNAME"):
+            raise ValueError(
+                "Please set the environment variable POSTGRES_DBNAME."
+            )
+        if not os.getenv("POSTGRES_USER"):
+            raise ValueError(
+                "Please set the environment variable POSTGRES_USER."
+            )
+        if not os.getenv("POSTGRES_PASSWORD"):
+            raise ValueError(
+                "Please set the environment variable POSTGRES_PASSWORD."
+            )
+        if not os.getenv("POSTGRES_HOST"):
+            raise ValueError(
+                "Please set the environment variable POSTGRES_HOST."
+            )
+        if not os.getenv("POSTGRES_PORT"):
+            raise ValueError(
+                "Please set the environment variable POSTGRES_PORT."
+            )
+
+    async def init(self):
+        self.pool = await asyncpg.create_pool(
+            database=os.getenv("POSTGRES_DBNAME"),
+            user=os.getenv("POSTGRES_USER"),
+            password=os.getenv("POSTGRES_PASSWORD"),
+            host=os.getenv("POSTGRES_HOST"),
+            port=os.getenv("POSTGRES_PORT"),
+            statement_cache_size=0,  # Disable statement caching
+        )
+        async with self.pool.acquire() as conn:
+            await conn.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS "{self.log_table}" (
+                    timestamp TIMESTAMPTZ,
+                    log_id UUID,
+                    key TEXT,
+                    value TEXT
+                )
+                """
+            )
+            await conn.execute(
+                f"""
+                CREATE TABLE IF NOT EXISTS "{self.log_info_table}" (
+                    timestamp TIMESTAMPTZ,
+                    log_id UUID UNIQUE,
+                    log_type TEXT
+                )
+            """
+            )
+
+    async def __aenter__(self):
+        if self.pool is None:
+            await self.init()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        await self.close()
+
+    async def close(self):
+        if self.pool:
+            await self.pool.close()
+            self.pool = None
+
+    async def log(
+        self,
+        log_id: uuid.UUID,
+        key: str,
+        value: str,
+        is_info_log=False,
+    ):
+        collection = self.log_info_table if is_info_log else self.log_table
+
+        if is_info_log:
+            if "type" not in key:
+                raise ValueError(
+                    "Info log key must contain the string `type`."
+                )
+            async with self.pool.acquire() as conn:
+                await conn.execute(
+                    f'INSERT INTO "{collection}" (timestamp, log_id, log_type) VALUES (NOW(), $1, $2)',
+                    log_id,
+                    value,
+                )
+        else:
+            async with self.pool.acquire() as conn:
+                await conn.execute(
+                    f'INSERT INTO "{collection}" (timestamp, log_id, key, value) VALUES (NOW(), $1, $2, $3)',
+                    log_id,
+                    key,
+                    value,
+                )
+
+    async def get_run_info(
+        self, limit: int = 10, log_type_filter: Optional[str] = None
+    ) -> list[RunInfo]:
+        query = f"SELECT log_id, log_type FROM {self.log_info_table}"
+        conditions = []
+        params = []
+        if log_type_filter:
+            conditions.append("log_type = $1")
+            params.append(log_type_filter)
+        if conditions:
+            query += " WHERE " + " AND ".join(conditions)
+        query += " ORDER BY timestamp DESC LIMIT $2"
+        params.append(limit)
+        async with self.pool.acquire() as conn:
+            rows = await conn.fetch(query, *params)
+            return [
+                RunInfo(run_id=row["log_id"], log_type=row["log_type"])
+                for row in rows
+            ]
+
+    async def get_logs(
+        self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+    ) -> list:
+        if not run_ids:
+            raise ValueError("No run ids provided.")
+
+        placeholders = ",".join([f"${i + 1}" for i in range(len(run_ids))])
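+        # Same windowed query as the local provider: rank rows per run,
+        # newest first, and keep at most limit_per_run rows per run_id.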
+        query = f"""
+        SELECT * FROM (
+            SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
+            FROM "{self.log_table}"
+            WHERE log_id::text IN ({placeholders})
+        ) sub
+        WHERE sub.rn <= ${len(run_ids) + 1}
+        ORDER BY sub.timestamp DESC
+        """
+        params = [str(run_id) for run_id in run_ids] + [limit_per_run]
+        async with self.pool.acquire() as conn:
+            rows = await conn.fetch(query, *params)
+            return [{key: row[key] for key in row.keys()} for row in rows]
+
+
+class RedisLoggingConfig(LoggingConfig):
+    provider: str = "redis"
+    log_table: str = "logs"
+    log_info_table: str = "logs_pipeline_info"
+
+    def validate(self) -> None:
+        required_env_vars = ["REDIS_CLUSTER_IP", "REDIS_CLUSTER_PORT"]
+        for var in required_env_vars:
+            if not os.getenv(var):
+                raise ValueError(f"Environment variable {var} is not set.")
+
+    @property
+    def supported_providers(self) -> list[str]:
+        return ["redis"]
+
+
+class RedisKVLoggingProvider(KVLoggingProvider):
+    def __init__(self, config: RedisLoggingConfig):
+        logger.info(
+            f"Initializing RedisKVLoggingProvider with config: {config}"
+        )
+
+        if not all(
+            [
+                os.getenv("REDIS_CLUSTER_IP"),
+                os.getenv("REDIS_CLUSTER_PORT"),
+            ]
+        ):
+            raise ValueError(
+                "Please set the environment variables REDIS_CLUSTER_IP and REDIS_CLUSTER_PORT to run `LoggingDatabaseConnection` with `redis`."
+            )
+        try:
+            from redis.asyncio import Redis
+        except ImportError:
+            raise ImportError(
+                "`redis` is not installed. Please install it using `pip install redis`."
+            )
+
+        cluster_ip = os.getenv("REDIS_CLUSTER_IP")
+        port = os.getenv("REDIS_CLUSTER_PORT")
+        self.redis = Redis(
+            host=cluster_ip, port=int(port), decode_responses=True
+        )
+        self.log_key = config.log_table
+        self.log_info_key = config.log_info_table
+
+    async def __aenter__(self):
+        return self
+
+    async def __aexit__(self, exc_type, exc, tb):
+        await self.close()
+
+    async def close(self):
+        await self.redis.close()
+
+    async def log(
+        self,
+        log_id: uuid.UUID,
+        key: str,
+        value: str,
+        is_info_log=False,
+    ):
+        timestamp = datetime.now().timestamp()
+        log_entry = {
+            "timestamp": timestamp,
+            "log_id": str(log_id),
+            "key": key,
+            "value": value,
+        }
+        if is_info_log:
+            if "type" not in key:
+                raise ValueError("Metadata keys must contain the text 'type'")
+            log_entry["log_type"] = value
+            await self.redis.hset(
+                self.log_info_key, str(log_id), json.dumps(log_entry)
+            )
+            await self.redis.zadd(
+                f"{self.log_info_key}_sorted", {str(log_id): timestamp}
+            )
+        else:
+            await self.redis.lpush(
+                f"{self.log_key}:{str(log_id)}", json.dumps(log_entry)
+            )
+
+    async def get_run_info(
+        self, limit: int = 10, log_type_filter: Optional[str] = None
+    ) -> list[RunInfo]:
+        run_info_list = []
+        start = 0
+        count_per_batch = 100  # Adjust batch size as needed
+
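+        # Page through the sorted set (newest first) and hydrate each run's
+        # info entry until enough matching runs have been collected.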
+        while len(run_info_list) < limit:
+            log_ids = await self.redis.zrevrange(
+                f"{self.log_info_key}_sorted",
+                start,
+                start + count_per_batch - 1,
+            )
+            if not log_ids:
+                break  # No more log IDs to process
+
+            start += count_per_batch
+
+            for log_id in log_ids:
+                log_entry = json.loads(
+                    await self.redis.hget(self.log_info_key, log_id)
+                )
+                if log_type_filter:
+                    if log_entry["log_type"] == log_type_filter:
+                        run_info_list.append(
+                            RunInfo(
+                                run_id=uuid.UUID(log_entry["log_id"]),
+                                log_type=log_entry["log_type"],
+                            )
+                        )
+                else:
+                    run_info_list.append(
+                        RunInfo(
+                            run_id=uuid.UUID(log_entry["log_id"]),
+                            log_type=log_entry["log_type"],
+                        )
+                    )
+
+                if len(run_info_list) >= limit:
+                    break
+
+        return run_info_list[:limit]
+
+    async def get_logs(
+        self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+    ) -> list:
+        logs = []
+        for run_id in run_ids:
+            raw_logs = await self.redis.lrange(
+                f"{self.log_key}:{str(run_id)}", 0, limit_per_run - 1
+            )
+            for raw_log in raw_logs:
+                json_log = json.loads(raw_log)
+                json_log["log_id"] = uuid.UUID(json_log["log_id"])
+                logs.append(json_log)
+        return logs
+
+
+class KVLoggingSingleton:
+    _instance = None
+    _is_configured = False
+    _config: Optional[LoggingConfig] = None
+
+    SUPPORTED_PROVIDERS = {
+        "local": LocalKVLoggingProvider,
+        "postgres": PostgresKVLoggingProvider,
+        "redis": RedisKVLoggingProvider,
+    }
+
+    @classmethod
+    def get_instance(cls):
+        if cls._config is None:
+            raise RuntimeError(
+                "KVLoggingSingleton must be configured before use."
+            )
+        return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config)
+
+    @classmethod
+    def configure(
+        cls, logging_config: Optional[LoggingConfig] = LoggingConfig()
+    ):
+        if not cls._is_configured:
+            cls._config = logging_config
+            cls._is_configured = True
+        else:
+            raise Exception("KVLoggingSingleton is already configured.")
+
+    @classmethod
+    async def log(
+        cls,
+        log_id: uuid.UUID,
+        key: str,
+        value: str,
+        is_info_log=False,
+    ):
+        try:
+            async with cls.get_instance() as provider:
+                await provider.log(log_id, key, value, is_info_log=is_info_log)
+
+        except Exception as e:
+            logger.error(f"Error logging data {(log_id, key, value)}: {e}")
+
+    @classmethod
+    async def get_run_info(
+        cls, limit: int = 10, log_type_filter: Optional[str] = None
+    ) -> list[RunInfo]:
+        async with cls.get_instance() as provider:
+            return await provider.get_run_info(
+                limit, log_type_filter=log_type_filter
+            )
+
+    @classmethod
+    async def get_logs(
+        cls, run_ids: list[uuid.UUID], limit_per_run: int = 10
+    ) -> list:
+        async with cls.get_instance() as provider:
+            return await provider.get_logs(run_ids, limit_per_run)
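
For orientation, a minimal, hypothetical usage sketch of the singleton above (not part of the commit). It assumes the module is importable as r2r.base.logging.kv_logger, that aiosqlite is installed for the default local provider, and uses an illustrative query string:

# Hypothetical usage sketch; not part of the diff.
import asyncio
import uuid

from r2r.base.logging.kv_logger import KVLoggingSingleton, LoggingConfig


async def main():
    # Configure once per process; the default LoggingConfig selects the
    # local SQLite provider (LOCAL_DB_PATH or "local.sqlite").
    KVLoggingSingleton.configure(LoggingConfig())

    run_id = uuid.uuid4()
    # Info logs register a run; the key must contain the text "type".
    await KVLoggingSingleton.log(
        run_id, "pipeline_type", "search", is_info_log=True
    )
    # Regular logs attach arbitrary key/value pairs to the run.
    await KVLoggingSingleton.log(run_id, "query", "what is a vector database?")

    runs = await KVLoggingSingleton.get_run_info(limit=5)
    logs = await KVLoggingSingleton.get_logs([run.run_id for run in runs])
    print(runs, logs)


asyncio.run(main())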
diff --git a/R2R/r2r/base/logging/log_processor.py b/R2R/r2r/base/logging/log_processor.py
new file mode 100755
index 00000000..e85d8de2
--- /dev/null
+++ b/R2R/r2r/base/logging/log_processor.py
@@ -0,0 +1,196 @@
+import contextlib
+import json
+import logging
+import statistics
+from collections import defaultdict
+from typing import Any, Callable, Dict, List, Optional, Sequence
+
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class FilterCriteria(BaseModel):
+    filters: Optional[dict[str, str]] = None
+
+
+class LogProcessor:
+    timestamp_format = "%Y-%m-%d %H:%M:%S"
+
+    def __init__(self, filters: Dict[str, Callable[[Dict[str, Any]], bool]]):
+        self.filters = filters
+        self.populations = {name: [] for name in filters}
+
+    def process_log(self, log: Dict[str, Any]):
+        for name, filter_func in self.filters.items():
+            if filter_func(log):
+                self.populations[name].append(log)
+
+
+class StatisticsCalculator:
+    @staticmethod
+    def calculate_statistics(
+        population: List[Dict[str, Any]],
+        stat_functions: Dict[str, Callable[[List[Dict[str, Any]]], Any]],
+    ) -> Dict[str, Any]:
+        return {
+            name: func(population) for name, func in stat_functions.items()
+        }
+
+
+class DistributionGenerator:
+    @staticmethod
+    def generate_distributions(
+        population: List[Dict[str, Any]],
+        dist_functions: Dict[str, Callable[[List[Dict[str, Any]]], Any]],
+    ) -> Dict[str, Any]:
+        return {
+            name: func(population) for name, func in dist_functions.items()
+        }
+
+
+class VisualizationPreparer:
+    @staticmethod
+    def prepare_visualization_data(
+        data: Dict[str, Any],
+        vis_functions: Dict[str, Callable[[Dict[str, Any]], Any]],
+    ) -> Dict[str, Any]:
+        return {name: func(data) for name, func in vis_functions.items()}
+
+
+class LogAnalyticsConfig:
+    def __init__(self, filters, stat_functions, dist_functions, vis_functions):
+        self.filters = filters
+        self.stat_functions = stat_functions
+        self.dist_functions = dist_functions
+        self.vis_functions = vis_functions
+
+
+class AnalysisTypes(BaseModel):
+    analysis_types: Optional[dict[str, Sequence[str]]] = None
+
+    @staticmethod
+    def generate_bar_chart_data(logs, key):
+        chart_data = {"labels": [], "datasets": []}
+        value_counts = defaultdict(int)
+
+        for log in logs:
+            if "entries" in log:
+                for entry in log["entries"]:
+                    if entry["key"] == key:
+                        value_counts[entry["value"]] += 1
+            elif "key" in log and log["key"] == key:
+                value_counts[log["value"]] += 1
+
+        for value, count in value_counts.items():
+            chart_data["labels"].append(value)
+            chart_data["datasets"].append({"label": key, "data": [count]})
+
+        return chart_data
+
+    @staticmethod
+    def calculate_basic_statistics(logs, key):
+        values = []
+        for log in logs:
+            # Only consider entries for the requested key.
+            if log.get("key") != key:
+                continue
+            if key == "search_results":
+                results = json.loads(log["value"])
+                scores = [
+                    float(json.loads(result)["score"]) for result in results
+                ]
+                values.extend(scores)
+            else:
+                value = log.get("value")
+                if value is not None:
+                    with contextlib.suppress(ValueError):
+                        values.append(float(value))
+
+        if not values:
+            return {
+                "Mean": None,
+                "Median": None,
+                "Mode": None,
+                "Standard Deviation": None,
+                "Variance": None,
+            }
+
+        if len(values) == 1:
+            single_value = round(values[0], 3)
+            return {
+                "Mean": single_value,
+                "Median": single_value,
+                "Mode": single_value,
+                "Standard Deviation": 0,
+                "Variance": 0,
+            }
+
+        mean = round(sum(values) / len(values), 3)
+        median = round(statistics.median(values), 3)
+        mode = (
+            round(statistics.mode(values), 3)
+            if len(set(values)) != len(values)
+            else None
+        )
+        std_dev = round(statistics.stdev(values) if len(values) > 1 else 0, 3)
+        variance = round(
+            statistics.variance(values) if len(values) > 1 else 0, 3
+        )
+
+        return {
+            "Mean": mean,
+            "Median": median,
+            "Mode": mode,
+            "Standard Deviation": std_dev,
+            "Variance": variance,
+        }
+
+    @staticmethod
+    def calculate_percentile(logs, key, percentile):
+        values = []
+        for log in logs:
+            if log["key"] == key:
+                value = log.get("value")
+                if value is not None:
+                    with contextlib.suppress(ValueError):
+                        values.append(float(value))
+
+        if not values:
+            return {"percentile": percentile, "value": None}
+
+        values.sort()
+        index = int((percentile / 100) * (len(values) - 1))
+        return {"percentile": percentile, "value": round(values[index], 3)}
+
+
+class LogAnalytics:
+    def __init__(self, logs: List[Dict[str, Any]], config: LogAnalyticsConfig):
+        self.logs = logs
+        self.log_processor = LogProcessor(config.filters)
+        self.statistics_calculator = StatisticsCalculator()
+        self.distribution_generator = DistributionGenerator()
+        self.visualization_preparer = VisualizationPreparer()
+        self.config = config
+
+    def count_logs(self) -> Dict[str, Any]:
+        """Count the logs for each filter."""
+        return {
+            name: len(population)
+            for name, population in self.log_processor.populations.items()
+        }
+
+    def process_logs(self) -> Dict[str, Any]:
+        for log in self.logs:
+            self.log_processor.process_log(log)
+
+        analytics = {}
+        for name, population in self.log_processor.populations.items():
+            stats = self.statistics_calculator.calculate_statistics(
+                population, self.config.stat_functions
+            )
+            dists = self.distribution_generator.generate_distributions(
+                population, self.config.dist_functions
+            )
+            analytics[name] = {"statistics": stats, "distributions": dists}
+
+        return self.visualization_preparer.prepare_visualization_data(
+            analytics, self.config.vis_functions
+        )
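
Likewise, a hypothetical sketch of how LogAnalytics might be driven (not part of the commit). The flat key/value record shape and the filter, statistic, distribution, and visualization callables are illustrative assumptions:

# Hypothetical usage sketch; not part of the diff.
from r2r.base.logging.log_processor import LogAnalytics, LogAnalyticsConfig

# Example log records in the flat key/value shape produced by the KV loggers.
logs = [
    {"log_id": "run-1", "key": "completion_time", "value": "0.42"},
    {"log_id": "run-1", "key": "completion_time", "value": "0.58"},
    {"log_id": "run-2", "key": "query", "value": "hello"},
]

config = LogAnalyticsConfig(
    # Partition logs into named populations.
    filters={"timings": lambda log: log.get("key") == "completion_time"},
    # Statistics computed per population.
    stat_functions={"count": len},
    # Distributions computed per population.
    dist_functions={
        "values": lambda population: [log["value"] for log in population]
    },
    # Final pass over the assembled analytics dict (identity here).
    vis_functions={"raw": lambda analytics: analytics},
)

analytics = LogAnalytics(logs, config)
print(analytics.process_logs())
# e.g. {'raw': {'timings': {'statistics': {'count': 2},
#       'distributions': {'values': ['0.42', '0.58']}}}}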
diff --git a/R2R/r2r/base/logging/run_manager.py b/R2R/r2r/base/logging/run_manager.py
new file mode 100755
index 00000000..ac192bca
--- /dev/null
+++ b/R2R/r2r/base/logging/run_manager.py
@@ -0,0 +1,56 @@
+import contextvars
+import uuid
+from contextlib import asynccontextmanager
+from typing import Any
+
+from .kv_logger import KVLoggingSingleton
+
+run_id_var = contextvars.ContextVar("run_id", default=None)
+
+
+class RunManager:
+    def __init__(self, logger: KVLoggingSingleton):
+        self.logger = logger
+        self.run_info = {}
+
+    def generate_run_id(self) -> uuid.UUID:
+        return uuid.uuid4()
+
+    async def set_run_info(self, pipeline_type: str):
+        run_id = run_id_var.get()
+        if run_id is None:
+            run_id = self.generate_run_id()
+            token = run_id_var.set(run_id)
+            self.run_info[run_id] = {"pipeline_type": pipeline_type}
+        else:
+            token = run_id_var.set(run_id)
+        return run_id, token
+
+    async def get_run_info(self):
+        run_id = run_id_var.get()
+        return self.run_info.get(run_id, None)
+
+    async def log_run_info(
+        self, key: str, value: Any, is_info_log: bool = False
+    ):
+        run_id = run_id_var.get()
+        if run_id:
+            await self.logger.log(
+                log_id=run_id, key=key, value=value, is_info_log=is_info_log
+            )
+
+    async def clear_run_info(self, token: contextvars.Token):
+        run_id = run_id_var.get()
+        run_id_var.reset(token)
+        if run_id and run_id in self.run_info:
+            del self.run_info[run_id]
+
+
+@asynccontextmanager
+async def manage_run(run_manager: RunManager, pipeline_type: str):
+    run_id, token = await run_manager.set_run_info(pipeline_type)
+    try:
+        yield run_id
+    finally:
+        # The run_info mapping is intentionally left in place so the run ID
+        # can still be looked up after the block; only the context variable
+        # is restored here.
+        run_id_var.reset(token)
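
Finally, a hypothetical sketch combining RunManager with the KV logger (not part of the commit). The import paths, the "search" pipeline type, and the query string are illustrative assumptions:

# Hypothetical usage sketch; not part of the diff.
import asyncio

from r2r.base.logging.kv_logger import KVLoggingSingleton, LoggingConfig
from r2r.base.logging.run_manager import RunManager, manage_run


async def run_pipeline():
    KVLoggingSingleton.configure(LoggingConfig())
    run_manager = RunManager(KVLoggingSingleton())

    # manage_run binds a run ID to the current context for everything awaited
    # inside the block and restores the previous value on exit.
    async with manage_run(run_manager, "search") as run_id:
        await run_manager.log_run_info(
            "pipeline_type", "search", is_info_log=True
        )
        await run_manager.log_run_info("query", "what is a vector database?")
        print("run id:", run_id)


asyncio.run(run_pipeline())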