import json
import logging
import os
import uuid
from abc import abstractmethod
from datetime import datetime
from typing import Optional

import asyncpg
from pydantic import BaseModel

from ..providers.base_provider import Provider, ProviderConfig

logger = logging.getLogger(__name__)


class RunInfo(BaseModel):
    run_id: uuid.UUID
    log_type: str


class LoggingConfig(ProviderConfig):
    provider: str = "local"
    log_table: str = "logs"
    log_info_table: str = "logs_pipeline_info"
    logging_path: Optional[str] = None

    def validate(self) -> None:
        pass

    @property
    def supported_providers(self) -> list[str]:
        return ["local", "postgres", "redis"]


class KVLoggingProvider(Provider):
    @abstractmethod
    async def close(self):
        pass

    @abstractmethod
    async def log(self, log_id: uuid.UUID, key: str, value: str):
        pass

    @abstractmethod
    async def get_run_info(
        self,
        limit: int = 10,
        log_type_filter: Optional[str] = None,
    ) -> list[RunInfo]:
        pass

    @abstractmethod
    async def get_logs(
        self, run_ids: list[uuid.UUID], limit_per_run: int
    ) -> list:
        pass


class LocalKVLoggingProvider(KVLoggingProvider):
    def __init__(self, config: LoggingConfig):
        self.log_table = config.log_table
        self.log_info_table = config.log_info_table
        self.logging_path = config.logging_path or os.getenv(
            "LOCAL_DB_PATH", "local.sqlite"
        )
        if not self.logging_path:
            raise ValueError(
                "Please set the environment variable LOCAL_DB_PATH."
            )
        self.conn = None
        try:
            import aiosqlite

            self.aiosqlite = aiosqlite
        except ImportError:
            raise ImportError(
                "Please install aiosqlite to use the LocalKVLoggingProvider."
            )

    async def init(self):
        self.conn = await self.aiosqlite.connect(self.logging_path)
        await self.conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.log_table} (
                timestamp DATETIME,
                log_id TEXT,
                key TEXT,
                value TEXT
            )
            """
        )
        await self.conn.execute(
            f"""
            CREATE TABLE IF NOT EXISTS {self.log_info_table} (
                timestamp DATETIME,
                log_id TEXT UNIQUE,
                log_type TEXT
            )
            """
        )
        await self.conn.commit()

    async def __aenter__(self):
        if self.conn is None:
            await self.init()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self):
        if self.conn:
            await self.conn.close()
            self.conn = None

    async def log(
        self,
        log_id: uuid.UUID,
        key: str,
        value: str,
        is_info_log=False,
    ):
        collection = self.log_info_table if is_info_log else self.log_table
        if is_info_log:
            if "type" not in key:
                raise ValueError("Info log keys must contain the text 'type'")
            await self.conn.execute(
                f"INSERT INTO {collection} (timestamp, log_id, log_type) VALUES (datetime('now'), ?, ?)",
                (str(log_id), value),
            )
        else:
            await self.conn.execute(
                f"INSERT INTO {collection} (timestamp, log_id, key, value) VALUES (datetime('now'), ?, ?, ?)",
                (str(log_id), key, value),
            )
        await self.conn.commit()

    async def get_run_info(
        self, limit: int = 10, log_type_filter: Optional[str] = None
    ) -> list[RunInfo]:
        cursor = await self.conn.cursor()
        query = f'SELECT log_id, log_type FROM "{self.log_info_table}"'
        conditions = []
        params = []
        if log_type_filter:
            conditions.append("log_type = ?")
            params.append(log_type_filter)
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        query += " ORDER BY timestamp DESC LIMIT ?"
        params.append(limit)
        await cursor.execute(query, params)
        rows = await cursor.fetchall()
        return [
            RunInfo(run_id=uuid.UUID(row[0]), log_type=row[1]) for row in rows
        ]

    async def get_logs(
        self, run_ids: list[uuid.UUID], limit_per_run: int = 10
    ) -> list:
        if not run_ids:
            raise ValueError("No run ids provided.")
        cursor = await self.conn.cursor()
        placeholders = ",".join(["?" for _ in run_ids])
        query = f"""
        SELECT *
        FROM (
            SELECT *,
                   ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
            FROM {self.log_table}
            WHERE log_id IN ({placeholders})
        )
        WHERE rn <= ?
        ORDER BY timestamp DESC
        """
        params = [str(ele) for ele in run_ids] + [limit_per_run]
        await cursor.execute(query, params)
        rows = await cursor.fetchall()
        new_rows = []
        for row in rows:
            new_rows.append(
                (row[0], uuid.UUID(row[1]), row[2], row[3], row[4])
            )
        return [
            {desc[0]: row[i] for i, desc in enumerate(cursor.description)}
            for row in new_rows
        ]
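

# Illustrative sketch (not part of the provider API): using LocalKVLoggingProvider
# directly. The async context manager handles `init()` (table creation) and
# `close()`. The "demo.sqlite" path and the key/value pairs below are
# hypothetical example values.
#
#   import asyncio
#
#   async def _demo_local_logging():
#       config = LoggingConfig(logging_path="demo.sqlite")
#       async with LocalKVLoggingProvider(config) as provider:
#           run_id = uuid.uuid4()
#           await provider.log(run_id, "pipeline_type", "ingestion", is_info_log=True)
#           await provider.log(run_id, "search_query", "example query")
#           print(await provider.get_run_info(limit=5))
#
#   asyncio.run(_demo_local_logging())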


class PostgresLoggingConfig(LoggingConfig):
    provider: str = "postgres"
    log_table: str = "logs"
    log_info_table: str = "logs_pipeline_info"

    def validate(self) -> None:
        required_env_vars = [
            "POSTGRES_DBNAME",
            "POSTGRES_USER",
            "POSTGRES_PASSWORD",
            "POSTGRES_HOST",
            "POSTGRES_PORT",
        ]
        for var in required_env_vars:
            if not os.getenv(var):
                raise ValueError(f"Environment variable {var} is not set.")

    @property
    def supported_providers(self) -> list[str]:
        return ["postgres"]


class PostgresKVLoggingProvider(KVLoggingProvider):
    def __init__(self, config: PostgresLoggingConfig):
        self.log_table = config.log_table
        self.log_info_table = config.log_info_table
        self.config = config
        self.pool = None
        if not os.getenv("POSTGRES_DBNAME"):
            raise ValueError(
                "Please set the environment variable POSTGRES_DBNAME."
            )
        if not os.getenv("POSTGRES_USER"):
            raise ValueError(
                "Please set the environment variable POSTGRES_USER."
            )
        if not os.getenv("POSTGRES_PASSWORD"):
            raise ValueError(
                "Please set the environment variable POSTGRES_PASSWORD."
            )
        if not os.getenv("POSTGRES_HOST"):
            raise ValueError(
                "Please set the environment variable POSTGRES_HOST."
            )
        if not os.getenv("POSTGRES_PORT"):
            raise ValueError(
                "Please set the environment variable POSTGRES_PORT."
            )

    async def init(self):
        self.pool = await asyncpg.create_pool(
            database=os.getenv("POSTGRES_DBNAME"),
            user=os.getenv("POSTGRES_USER"),
            password=os.getenv("POSTGRES_PASSWORD"),
            host=os.getenv("POSTGRES_HOST"),
            port=os.getenv("POSTGRES_PORT"),
            statement_cache_size=0,  # Disable statement caching
        )
        async with self.pool.acquire() as conn:
            await conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS "{self.log_table}" (
                    timestamp TIMESTAMPTZ,
                    log_id UUID,
                    key TEXT,
                    value TEXT
                )
                """
            )
            await conn.execute(
                f"""
                CREATE TABLE IF NOT EXISTS "{self.log_info_table}" (
                    timestamp TIMESTAMPTZ,
                    log_id UUID UNIQUE,
                    log_type TEXT
                )
                """
            )

    async def __aenter__(self):
        if self.pool is None:
            await self.init()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        await self.close()

    async def close(self):
        if self.pool:
            await self.pool.close()
            self.pool = None

    async def log(
        self,
        log_id: uuid.UUID,
        key: str,
        value: str,
        is_info_log=False,
    ):
        collection = self.log_info_table if is_info_log else self.log_table
        if is_info_log:
            if "type" not in key:
                raise ValueError(
                    "Info log key must contain the string `type`."
                )
            async with self.pool.acquire() as conn:
                # Run the insert on the connection acquired above rather than
                # on the pool, so the checkout is actually used.
                await conn.execute(
                    f'INSERT INTO "{collection}" (timestamp, log_id, log_type) VALUES (NOW(), $1, $2)',
                    log_id,
                    value,
                )
        else:
            async with self.pool.acquire() as conn:
                await conn.execute(
                    f'INSERT INTO "{collection}" (timestamp, log_id, key, value) VALUES (NOW(), $1, $2, $3)',
                    log_id,
                    key,
                    value,
                )

    async def get_run_info(
        self, limit: int = 10, log_type_filter: Optional[str] = None
    ) -> list[RunInfo]:
        query = f'SELECT log_id, log_type FROM "{self.log_info_table}"'
        conditions = []
        params = []
        if log_type_filter:
            conditions.append("log_type = $1")
            params.append(log_type_filter)
        if conditions:
            query += " WHERE " + " AND ".join(conditions)
        # Number the LIMIT placeholder after any filter parameters so the
        # query is valid both with and without a log_type filter.
        params.append(limit)
        query += f" ORDER BY timestamp DESC LIMIT ${len(params)}"
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [
            RunInfo(run_id=row["log_id"], log_type=row["log_type"])
            for row in rows
        ]

    async def get_logs(
        self, run_ids: list[uuid.UUID], limit_per_run: int = 10
    ) -> list:
        if not run_ids:
            raise ValueError("No run ids provided.")
        placeholders = ",".join([f"${i + 1}" for i in range(len(run_ids))])
        # Keep at most `limit_per_run` rows per run id, newest first.
        query = f"""
        SELECT * FROM (
            SELECT *,
                   ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
            FROM "{self.log_table}"
            WHERE log_id::text IN ({placeholders})
        ) sub
        WHERE sub.rn <= ${len(run_ids) + 1}
        ORDER BY sub.timestamp DESC
        """
        params = [str(run_id) for run_id in run_ids] + [limit_per_run]
        async with self.pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [{key: row[key] for key in row.keys()} for row in rows]
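

# Configuration sketch for PostgresKVLoggingProvider: connection settings come
# from the environment variables validated above. The values shown here are
# placeholders, not defaults.
#
#   export POSTGRES_DBNAME=mydb
#   export POSTGRES_USER=postgres
#   export POSTGRES_PASSWORD=...
#   export POSTGRES_HOST=localhost
#   export POSTGRES_PORT=5432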


class RedisLoggingConfig(LoggingConfig):
    provider: str = "redis"
    log_table: str = "logs"
    log_info_table: str = "logs_pipeline_info"

    def validate(self) -> None:
        required_env_vars = ["REDIS_CLUSTER_IP", "REDIS_CLUSTER_PORT"]
        for var in required_env_vars:
            if not os.getenv(var):
                raise ValueError(f"Environment variable {var} is not set.")

    @property
    def supported_providers(self) -> list[str]:
        return ["redis"]


class RedisKVLoggingProvider(KVLoggingProvider):
    def __init__(self, config: RedisLoggingConfig):
        logger.info(
            f"Initializing RedisKVLoggingProvider with config: {config}"
        )
        if not all(
            [
                os.getenv("REDIS_CLUSTER_IP"),
                os.getenv("REDIS_CLUSTER_PORT"),
            ]
        ):
            raise ValueError(
                "Please set the environment variables REDIS_CLUSTER_IP and REDIS_CLUSTER_PORT to run `LoggingDatabaseConnection` with `redis`."
            )
        try:
            from redis.asyncio import Redis
        except ImportError:
            raise ValueError(
                "Error, `redis` is not installed. Please install it using `pip install redis`."
            )
        cluster_ip = os.getenv("REDIS_CLUSTER_IP")
        port = int(os.getenv("REDIS_CLUSTER_PORT"))
        self.redis = Redis(host=cluster_ip, port=port, decode_responses=True)
        self.log_key = config.log_table
        self.log_info_key = config.log_info_table

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc, tb):
        await self.close()

    async def close(self):
        await self.redis.close()

    async def log(
        self,
        log_id: uuid.UUID,
        key: str,
        value: str,
        is_info_log=False,
    ):
        timestamp = datetime.now().timestamp()
        log_entry = {
            "timestamp": timestamp,
            "log_id": str(log_id),
            "key": key,
            "value": value,
        }
        if is_info_log:
            if "type" not in key:
                raise ValueError("Metadata keys must contain the text 'type'")
            log_entry["log_type"] = value
            # Info logs live in a hash keyed by run id, plus a sorted set
            # scored by timestamp for "most recent runs" queries.
            await self.redis.hset(
                self.log_info_key, str(log_id), json.dumps(log_entry)
            )
            await self.redis.zadd(
                f"{self.log_info_key}_sorted", {str(log_id): timestamp}
            )
        else:
            # Regular logs are pushed onto a per-run list, newest first.
            await self.redis.lpush(
                f"{self.log_key}:{str(log_id)}", json.dumps(log_entry)
            )

    async def get_run_info(
        self, limit: int = 10, log_type_filter: Optional[str] = None
    ) -> list[RunInfo]:
        run_info_list = []
        start = 0
        count_per_batch = 100  # Adjust batch size as needed
        while len(run_info_list) < limit:
            log_ids = await self.redis.zrevrange(
                f"{self.log_info_key}_sorted",
                start,
                start + count_per_batch - 1,
            )
            if not log_ids:
                break  # No more log IDs to process
            start += count_per_batch
            for log_id in log_ids:
                log_entry = json.loads(
                    await self.redis.hget(self.log_info_key, log_id)
                )
                if (
                    not log_type_filter
                    or log_entry["log_type"] == log_type_filter
                ):
                    run_info_list.append(
                        RunInfo(
                            run_id=uuid.UUID(log_entry["log_id"]),
                            log_type=log_entry["log_type"],
                        )
                    )
                if len(run_info_list) >= limit:
                    break
        return run_info_list[:limit]

    async def get_logs(
        self, run_ids: list[uuid.UUID], limit_per_run: int = 10
    ) -> list:
        logs = []
        for run_id in run_ids:
            raw_logs = await self.redis.lrange(
                f"{self.log_key}:{str(run_id)}", 0, limit_per_run - 1
            )
            for raw_log in raw_logs:
                json_log = json.loads(raw_log)
                json_log["log_id"] = uuid.UUID(json_log["log_id"])
                logs.append(json_log)
        return logs
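

# Key layout used by RedisKVLoggingProvider above (derived from the calls in
# `log`), assuming the default "logs" / "logs_pipeline_info" table names:
#
#   logs_pipeline_info          HASH  run id -> JSON info entry
#   logs_pipeline_info_sorted   ZSET  run id scored by timestamp
#   logs:<run id>               LIST  JSON log entries, newest first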


class KVLoggingSingleton:
    _instance = None
    _is_configured = False
    _config: Optional[LoggingConfig] = None

    SUPPORTED_PROVIDERS = {
        "local": LocalKVLoggingProvider,
        "postgres": PostgresKVLoggingProvider,
        "redis": RedisKVLoggingProvider,
    }

    @classmethod
    def get_instance(cls):
        if cls._config is None:
            raise ValueError(
                "KVLoggingSingleton is not configured. Call `configure` first."
            )
        return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config)

    @classmethod
    def configure(
        cls, logging_config: Optional[LoggingConfig] = LoggingConfig()
    ):
        if not cls._is_configured:
            cls._config = logging_config
            cls._is_configured = True
        else:
            raise Exception("KVLoggingSingleton is already configured.")

    @classmethod
    async def log(
        cls,
        log_id: uuid.UUID,
        key: str,
        value: str,
        is_info_log=False,
    ):
        try:
            async with cls.get_instance() as provider:
                await provider.log(log_id, key, value, is_info_log=is_info_log)
        except Exception as e:
            logger.error(f"Error logging data {(log_id, key, value)}: {e}")

    @classmethod
    async def get_run_info(
        cls, limit: int = 10, log_type_filter: Optional[str] = None
    ) -> list[RunInfo]:
        async with cls.get_instance() as provider:
            return await provider.get_run_info(
                limit, log_type_filter=log_type_filter
            )

    @classmethod
    async def get_logs(
        cls, run_ids: list[uuid.UUID], limit_per_run: int = 10
    ) -> list:
        async with cls.get_instance() as provider:
            return await provider.get_logs(run_ids, limit_per_run)
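

# Minimal usage sketch, assuming the default local (SQLite) provider and that
# aiosqlite is installed. The run id and key/value pairs below are illustrative
# only; in the pipelines this singleton is configured once at startup.
if __name__ == "__main__":
    import asyncio

    async def _demo() -> None:
        KVLoggingSingleton.configure(LoggingConfig())  # provider="local"
        run_id = uuid.uuid4()
        await KVLoggingSingleton.log(
            run_id, "pipeline_type", "ingestion", is_info_log=True
        )
        await KVLoggingSingleton.log(run_id, "search_query", "example query")
        print(await KVLoggingSingleton.get_run_info(limit=5))
        print(await KVLoggingSingleton.get_logs([run_id]))

    asyncio.run(_demo())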