path: root/R2R/r2r/base/logging/kv_logger.py
author      S. Solomon Darnell    2025-03-28 21:52:21 -0500
committer   S. Solomon Darnell    2025-03-28 21:52:21 -0500
commit      4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
tree        ee3dc5af3b6313e921cd920906356f5d4febc4ed /R2R/r2r/base/logging/kv_logger.py
parent      cc961e04ba734dd72309fb548a2f97d67d578813 (diff)
download    gn-ai-4a52a71956a8d46fcb7294ac71734504bb09bcc2.tar.gz
two versions of R2R are here (HEAD, master)
Diffstat (limited to 'R2R/r2r/base/logging/kv_logger.py')
-rwxr-xr-x    R2R/r2r/base/logging/kv_logger.py    547
1 file changed, 547 insertions, 0 deletions
diff --git a/R2R/r2r/base/logging/kv_logger.py b/R2R/r2r/base/logging/kv_logger.py
new file mode 100755
index 00000000..2d444e9f
--- /dev/null
+++ b/R2R/r2r/base/logging/kv_logger.py
@@ -0,0 +1,547 @@
+import json
+import logging
+import os
+import uuid
+from abc import abstractmethod
+from datetime import datetime
+from typing import Optional
+
+import asyncpg
+from pydantic import BaseModel
+
+from ..providers.base_provider import Provider, ProviderConfig
+
+logger = logging.getLogger(__name__)
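+
+# Overview: this module provides key-value run logging behind a common
+# interface, with three backends (SQLite via aiosqlite, Postgres via asyncpg,
+# Redis via redis.asyncio) and a singleton facade, KVLoggingSingleton, that
+# constructs a fresh provider per call from a shared LoggingConfig.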
+
+
+class RunInfo(BaseModel):
+ run_id: uuid.UUID
+ log_type: str
+
+
+class LoggingConfig(ProviderConfig):
+ provider: str = "local"
+ log_table: str = "logs"
+ log_info_table: str = "logs_pipeline_info"
+ logging_path: Optional[str] = None
+
+ def validate(self) -> None:
+ pass
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["local", "postgres", "redis"]
+
+
+class KVLoggingProvider(Provider):
+ @abstractmethod
+ async def close(self):
+ pass
+
+    @abstractmethod
+    async def log(
+        self,
+        log_id: uuid.UUID,
+        key: str,
+        value: str,
+        is_info_log: bool = False,
+    ):
+        pass
+
+ @abstractmethod
+ async def get_run_info(
+ self,
+ limit: int = 10,
+ log_type_filter: Optional[str] = None,
+ ) -> list[RunInfo]:
+ pass
+
+ @abstractmethod
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int
+ ) -> list:
+ pass
+
+
+class LocalKVLoggingProvider(KVLoggingProvider):
+ def __init__(self, config: LoggingConfig):
+ self.log_table = config.log_table
+ self.log_info_table = config.log_info_table
+        # Fall back to the LOCAL_DB_PATH env var, then to a local SQLite file.
+        self.logging_path = config.logging_path or os.getenv(
+            "LOCAL_DB_PATH", "local.sqlite"
+        )
+ self.conn = None
+ try:
+ import aiosqlite
+
+ self.aiosqlite = aiosqlite
+ except ImportError:
+ raise ImportError(
+ "Please install aiosqlite to use the LocalKVLoggingProvider."
+ )
+
+ async def init(self):
+ self.conn = await self.aiosqlite.connect(self.logging_path)
+ await self.conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {self.log_table} (
+ timestamp DATETIME,
+ log_id TEXT,
+ key TEXT,
+ value TEXT
+ )
+ """
+ )
+ await self.conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS {self.log_info_table} (
+ timestamp DATETIME,
+ log_id TEXT UNIQUE,
+ log_type TEXT
+ )
+ """
+ )
+ await self.conn.commit()
+
+ async def __aenter__(self):
+ if self.conn is None:
+ await self.init()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ async def close(self):
+ if self.conn:
+ await self.conn.close()
+ self.conn = None
+
+ async def log(
+ self,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+        is_info_log: bool = False,
+ ):
+ collection = self.log_info_table if is_info_log else self.log_table
+
+ if is_info_log:
+ if "type" not in key:
+ raise ValueError("Info log keys must contain the text 'type'")
+ await self.conn.execute(
+ f"INSERT INTO {collection} (timestamp, log_id, log_type) VALUES (datetime('now'), ?, ?)",
+ (str(log_id), value),
+ )
+ else:
+ await self.conn.execute(
+ f"INSERT INTO {collection} (timestamp, log_id, key, value) VALUES (datetime('now'), ?, ?, ?)",
+ (str(log_id), key, value),
+ )
+ await self.conn.commit()
+
+ async def get_run_info(
+ self, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ cursor = await self.conn.cursor()
+ query = f'SELECT log_id, log_type FROM "{self.log_info_table}"'
+ conditions = []
+ params = []
+ if log_type_filter:
+ conditions.append("log_type = ?")
+ params.append(log_type_filter)
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+ query += " ORDER BY timestamp DESC LIMIT ?"
+ params.append(limit)
+ await cursor.execute(query, params)
+ rows = await cursor.fetchall()
+ return [
+ RunInfo(run_id=uuid.UUID(row[0]), log_type=row[1]) for row in rows
+ ]
+
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ if not run_ids:
+ raise ValueError("No run ids provided.")
+ cursor = await self.conn.cursor()
+ placeholders = ",".join(["?" for _ in run_ids])
+ query = f"""
+ SELECT *
+ FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
+ FROM {self.log_table}
+ WHERE log_id IN ({placeholders})
+ )
+ WHERE rn <= ?
+ ORDER BY timestamp DESC
+ """
+ params = [str(ele) for ele in run_ids] + [limit_per_run]
+ await cursor.execute(query, params)
+ rows = await cursor.fetchall()
+ new_rows = []
+ for row in rows:
+ new_rows.append(
+ (row[0], uuid.UUID(row[1]), row[2], row[3], row[4])
+ )
+ return [
+ {desc[0]: row[i] for i, desc in enumerate(cursor.description)}
+ for row in new_rows
+ ]
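+
+# Usage sketch for the local provider (a minimal example; assumes the working
+# directory is writable for the default "local.sqlite" file):
+#
+#     async with LocalKVLoggingProvider(LoggingConfig()) as provider:
+#         run_id = uuid.uuid4()
+#         await provider.log(run_id, "run_type", "ingestion", is_info_log=True)
+#         await provider.log(run_id, "status", "started")
+#         print(await provider.get_logs([run_id]))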
+
+
+class PostgresLoggingConfig(LoggingConfig):
+ provider: str = "postgres"
+ log_table: str = "logs"
+ log_info_table: str = "logs_pipeline_info"
+
+ def validate(self) -> None:
+ required_env_vars = [
+ "POSTGRES_DBNAME",
+ "POSTGRES_USER",
+ "POSTGRES_PASSWORD",
+ "POSTGRES_HOST",
+ "POSTGRES_PORT",
+ ]
+ for var in required_env_vars:
+ if not os.getenv(var):
+ raise ValueError(f"Environment variable {var} is not set.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["postgres"]
+
+
+class PostgresKVLoggingProvider(KVLoggingProvider):
+ def __init__(self, config: PostgresLoggingConfig):
+ self.log_table = config.log_table
+ self.log_info_table = config.log_info_table
+ self.config = config
+ self.pool = None
+ if not os.getenv("POSTGRES_DBNAME"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_DBNAME."
+ )
+ if not os.getenv("POSTGRES_USER"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_USER."
+ )
+ if not os.getenv("POSTGRES_PASSWORD"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_PASSWORD."
+ )
+ if not os.getenv("POSTGRES_HOST"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_HOST."
+ )
+ if not os.getenv("POSTGRES_PORT"):
+ raise ValueError(
+ "Please set the environment variable POSTGRES_PORT."
+ )
+
+ async def init(self):
+ self.pool = await asyncpg.create_pool(
+ database=os.getenv("POSTGRES_DBNAME"),
+ user=os.getenv("POSTGRES_USER"),
+ password=os.getenv("POSTGRES_PASSWORD"),
+ host=os.getenv("POSTGRES_HOST"),
+ port=os.getenv("POSTGRES_PORT"),
+ statement_cache_size=0, # Disable statement caching
+ )
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS "{self.log_table}" (
+ timestamp TIMESTAMPTZ,
+ log_id UUID,
+ key TEXT,
+ value TEXT
+ )
+ """
+ )
+ await conn.execute(
+ f"""
+ CREATE TABLE IF NOT EXISTS "{self.log_info_table}" (
+ timestamp TIMESTAMPTZ,
+ log_id UUID UNIQUE,
+ log_type TEXT
+ )
+ """
+ )
+
+ async def __aenter__(self):
+ if self.pool is None:
+ await self.init()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ await self.close()
+
+ async def close(self):
+ if self.pool:
+ await self.pool.close()
+ self.pool = None
+
+ async def log(
+ self,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+        is_info_log: bool = False,
+ ):
+ collection = self.log_info_table if is_info_log else self.log_table
+
+ if is_info_log:
+ if "type" not in key:
+ raise ValueError(
+ "Info log key must contain the string `type`."
+ )
+            async with self.pool.acquire() as conn:
+                await conn.execute(
+                    f'INSERT INTO "{collection}" (timestamp, log_id, log_type) VALUES (NOW(), $1, $2)',
+                    log_id,
+                    value,
+                )
+ else:
+ async with self.pool.acquire() as conn:
+ await conn.execute(
+ f'INSERT INTO "{collection}" (timestamp, log_id, key, value) VALUES (NOW(), $1, $2, $3)',
+ log_id,
+ key,
+ value,
+ )
+
+ async def get_run_info(
+ self, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ query = f"SELECT log_id, log_type FROM {self.log_info_table}"
+ conditions = []
+ params = []
+ if log_type_filter:
+ conditions.append("log_type = $1")
+ params.append(log_type_filter)
+ if conditions:
+ query += " WHERE " + " AND ".join(conditions)
+ query += " ORDER BY timestamp DESC LIMIT $2"
+ params.append(limit)
+ async with self.pool.acquire() as conn:
+ rows = await conn.fetch(query, *params)
+ return [
+ RunInfo(run_id=row["log_id"], log_type=row["log_type"])
+ for row in rows
+ ]
+
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ if not run_ids:
+ raise ValueError("No run ids provided.")
+
+ placeholders = ",".join([f"${i + 1}" for i in range(len(run_ids))])
+ query = f"""
+ SELECT * FROM (
+ SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
+ FROM "{self.log_table}"
+ WHERE log_id::text IN ({placeholders})
+ ) sub
+ WHERE sub.rn <= ${len(run_ids) + 1}
+ ORDER BY sub.timestamp DESC
+ """
+ params = [str(run_id) for run_id in run_ids] + [limit_per_run]
+ async with self.pool.acquire() as conn:
+ rows = await conn.fetch(query, *params)
+ return [{key: row[key] for key in row.keys()} for row in rows]
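+
+# Usage sketch for the Postgres provider (assumes the POSTGRES_* environment
+# variables checked above are set and point at a reachable server):
+#
+#     config = PostgresLoggingConfig()
+#     config.validate()
+#     async with PostgresKVLoggingProvider(config) as provider:
+#         await provider.log(uuid.uuid4(), "status", "started")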
+
+
+class RedisLoggingConfig(LoggingConfig):
+ provider: str = "redis"
+ log_table: str = "logs"
+ log_info_table: str = "logs_pipeline_info"
+
+ def validate(self) -> None:
+ required_env_vars = ["REDIS_CLUSTER_IP", "REDIS_CLUSTER_PORT"]
+ for var in required_env_vars:
+ if not os.getenv(var):
+ raise ValueError(f"Environment variable {var} is not set.")
+
+ @property
+ def supported_providers(self) -> list[str]:
+ return ["redis"]
+
+
+class RedisKVLoggingProvider(KVLoggingProvider):
+ def __init__(self, config: RedisLoggingConfig):
+ logger.info(
+ f"Initializing RedisKVLoggingProvider with config: {config}"
+ )
+
+ if not all(
+ [
+ os.getenv("REDIS_CLUSTER_IP"),
+ os.getenv("REDIS_CLUSTER_PORT"),
+ ]
+ ):
+            raise ValueError(
+                "Please set the environment variables REDIS_CLUSTER_IP and REDIS_CLUSTER_PORT to use `RedisKVLoggingProvider`."
+            )
+ try:
+ from redis.asyncio import Redis
+ except ImportError:
+            raise ImportError(
+                "Error, `redis` is not installed. Please install it using `pip install redis`."
+            )
+
+        cluster_ip = os.getenv("REDIS_CLUSTER_IP")
+        port = int(os.getenv("REDIS_CLUSTER_PORT"))
+        self.redis = Redis(host=cluster_ip, port=port, decode_responses=True)
+ self.log_key = config.log_table
+ self.log_info_key = config.log_info_table
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.close()
+
+ async def close(self):
+ await self.redis.close()
+
+ async def log(
+ self,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+        is_info_log: bool = False,
+ ):
+ timestamp = datetime.now().timestamp()
+ log_entry = {
+ "timestamp": timestamp,
+ "log_id": str(log_id),
+ "key": key,
+ "value": value,
+ }
+ if is_info_log:
+ if "type" not in key:
+ raise ValueError("Metadata keys must contain the text 'type'")
+ log_entry["log_type"] = value
+ await self.redis.hset(
+ self.log_info_key, str(log_id), json.dumps(log_entry)
+ )
+ await self.redis.zadd(
+ f"{self.log_info_key}_sorted", {str(log_id): timestamp}
+ )
+ else:
+ await self.redis.lpush(
+ f"{self.log_key}:{str(log_id)}", json.dumps(log_entry)
+ )
+
+ async def get_run_info(
+ self, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ run_info_list = []
+ start = 0
+ count_per_batch = 100 # Adjust batch size as needed
+
+ while len(run_info_list) < limit:
+ log_ids = await self.redis.zrevrange(
+ f"{self.log_info_key}_sorted",
+ start,
+ start + count_per_batch - 1,
+ )
+ if not log_ids:
+ break # No more log IDs to process
+
+ start += count_per_batch
+
+ for log_id in log_ids:
+ log_entry = json.loads(
+ await self.redis.hget(self.log_info_key, log_id)
+ )
+ if log_type_filter:
+ if log_entry["log_type"] == log_type_filter:
+ run_info_list.append(
+ RunInfo(
+ run_id=uuid.UUID(log_entry["log_id"]),
+ log_type=log_entry["log_type"],
+ )
+ )
+ else:
+ run_info_list.append(
+ RunInfo(
+ run_id=uuid.UUID(log_entry["log_id"]),
+ log_type=log_entry["log_type"],
+ )
+ )
+
+ if len(run_info_list) >= limit:
+ break
+
+ return run_info_list[:limit]
+
+ async def get_logs(
+ self, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ logs = []
+ for run_id in run_ids:
+ raw_logs = await self.redis.lrange(
+ f"{self.log_key}:{str(run_id)}", 0, limit_per_run - 1
+ )
+ for raw_log in raw_logs:
+ json_log = json.loads(raw_log)
+ json_log["log_id"] = uuid.UUID(json_log["log_id"])
+ logs.append(json_log)
+ return logs
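+
+# Usage sketch for the Redis provider (assumes REDIS_CLUSTER_IP and
+# REDIS_CLUSTER_PORT are set and `pip install redis` has been run):
+#
+#     async with RedisKVLoggingProvider(RedisLoggingConfig()) as provider:
+#         run_id = uuid.uuid4()
+#         await provider.log(run_id, "run_type", "search", is_info_log=True)
+#         print(await provider.get_run_info(limit=5))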
+
+
+class KVLoggingSingleton:
+    _instance = None
+    _is_configured = False
+    # Default config so `get_instance` works even before `configure` is called.
+    _config: LoggingConfig = LoggingConfig()
+
+ SUPPORTED_PROVIDERS = {
+ "local": LocalKVLoggingProvider,
+ "postgres": PostgresKVLoggingProvider,
+ "redis": RedisKVLoggingProvider,
+ }
+
+ @classmethod
+ def get_instance(cls):
+ return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config)
+
+    @classmethod
+    def configure(
+        cls, logging_config: Optional[LoggingConfig] = None
+    ):
+        if not cls._is_configured:
+            cls._config = logging_config or LoggingConfig()
+            cls._is_configured = True
+        else:
+            raise Exception("KVLoggingSingleton is already configured.")
+
+ @classmethod
+ async def log(
+ cls,
+ log_id: uuid.UUID,
+ key: str,
+ value: str,
+        is_info_log: bool = False,
+ ):
+ try:
+ async with cls.get_instance() as provider:
+ await provider.log(log_id, key, value, is_info_log=is_info_log)
+
+ except Exception as e:
+ logger.error(f"Error logging data {(log_id, key, value)}: {e}")
+
+ @classmethod
+ async def get_run_info(
+ cls, limit: int = 10, log_type_filter: Optional[str] = None
+ ) -> list[RunInfo]:
+ async with cls.get_instance() as provider:
+ return await provider.get_run_info(
+ limit, log_type_filter=log_type_filter
+ )
+
+ @classmethod
+ async def get_logs(
+ cls, run_ids: list[uuid.UUID], limit_per_run: int = 10
+ ) -> list:
+ async with cls.get_instance() as provider:
+ return await provider.get_logs(run_ids, limit_per_run)
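+
+# Usage sketch for the singleton facade (a minimal example; assumes the
+# default local provider, with `configure` called once before the first log):
+#
+#     import asyncio
+#
+#     async def main():
+#         KVLoggingSingleton.configure()  # defaults to LoggingConfig()
+#         run_id = uuid.uuid4()
+#         await KVLoggingSingleton.log(run_id, "run_type", "demo", is_info_log=True)
+#         await KVLoggingSingleton.log(run_id, "step", "embedding")
+#         print(await KVLoggingSingleton.get_run_info(limit=5))
+#
+#     asyncio.run(main())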