import json
import logging
import os
import uuid
from abc import abstractmethod
from datetime import datetime
from typing import Optional
import asyncpg
from pydantic import BaseModel
from ..providers.base_provider import Provider, ProviderConfig
logger = logging.getLogger(__name__)
class RunInfo(BaseModel):
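    """Identifier and log type for a single pipeline run."""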
run_id: uuid.UUID
log_type: str
class LoggingConfig(ProviderConfig):
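    """Base configuration for KV logging providers (local, postgres, or redis)."""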
provider: str = "local"
log_table: str = "logs"
log_info_table: str = "logs_pipeline_info"
logging_path: Optional[str] = None
def validate(self) -> None:
pass
@property
def supported_providers(self) -> list[str]:
return ["local", "postgres", "redis"]
class KVLoggingProvider(Provider):
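    """Abstract async key-value logging interface.

    Concrete providers must support writing log entries and retrieving
    run info and per-run logs.
    """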
@abstractmethod
async def close(self):
pass
@abstractmethod
async def log(self, log_id: uuid.UUID, key: str, value: str):
pass
@abstractmethod
async def get_run_info(
self,
limit: int = 10,
log_type_filter: Optional[str] = None,
) -> list[RunInfo]:
pass
@abstractmethod
async def get_logs(
self, run_ids: list[uuid.UUID], limit_per_run: int
) -> list:
pass
class LocalKVLoggingProvider(KVLoggingProvider):
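    """KV logging provider backed by a local SQLite database via aiosqlite."""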
def __init__(self, config: LoggingConfig):
self.log_table = config.log_table
self.log_info_table = config.log_info_table
        self.logging_path = config.logging_path or os.getenv("LOCAL_DB_PATH")
        if not self.logging_path:
            raise ValueError(
                "Please set the environment variable LOCAL_DB_PATH."
            )
self.conn = None
try:
import aiosqlite
self.aiosqlite = aiosqlite
except ImportError:
raise ImportError(
"Please install aiosqlite to use the LocalKVLoggingProvider."
)
async def init(self):
self.conn = await self.aiosqlite.connect(self.logging_path)
await self.conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.log_table} (
timestamp DATETIME,
log_id TEXT,
key TEXT,
value TEXT
)
"""
)
await self.conn.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.log_info_table} (
timestamp DATETIME,
log_id TEXT UNIQUE,
log_type TEXT
)
"""
)
await self.conn.commit()
async def __aenter__(self):
if self.conn is None:
await self.init()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async def close(self):
if self.conn:
await self.conn.close()
self.conn = None
async def log(
self,
log_id: uuid.UUID,
key: str,
value: str,
is_info_log=False,
):
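        # Info logs record the run's log_type in the info table; all other
        # entries are appended as key/value rows in the main log table.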
collection = self.log_info_table if is_info_log else self.log_table
if is_info_log:
if "type" not in key:
raise ValueError("Info log keys must contain the text 'type'")
await self.conn.execute(
f"INSERT INTO {collection} (timestamp, log_id, log_type) VALUES (datetime('now'), ?, ?)",
(str(log_id), value),
)
else:
await self.conn.execute(
f"INSERT INTO {collection} (timestamp, log_id, key, value) VALUES (datetime('now'), ?, ?, ?)",
(str(log_id), key, value),
)
await self.conn.commit()
async def get_run_info(
self, limit: int = 10, log_type_filter: Optional[str] = None
) -> list[RunInfo]:
cursor = await self.conn.cursor()
query = f'SELECT log_id, log_type FROM "{self.log_info_table}"'
conditions = []
params = []
if log_type_filter:
conditions.append("log_type = ?")
params.append(log_type_filter)
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY timestamp DESC LIMIT ?"
params.append(limit)
await cursor.execute(query, params)
rows = await cursor.fetchall()
return [
RunInfo(run_id=uuid.UUID(row[0]), log_type=row[1]) for row in rows
]
async def get_logs(
self, run_ids: list[uuid.UUID], limit_per_run: int = 10
) -> list:
if not run_ids:
raise ValueError("No run ids provided.")
cursor = await self.conn.cursor()
placeholders = ",".join(["?" for _ in run_ids])
query = f"""
SELECT *
FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
FROM {self.log_table}
WHERE log_id IN ({placeholders})
)
WHERE rn <= ?
ORDER BY timestamp DESC
"""
params = [str(ele) for ele in run_ids] + [limit_per_run]
await cursor.execute(query, params)
rows = await cursor.fetchall()
        # Map column names onto each row, converting the stored log_id back to a UUID.
        columns = [desc[0] for desc in cursor.description]
        return [
            dict(
                zip(columns, (row[0], uuid.UUID(row[1]), row[2], row[3], row[4]))
            )
            for row in rows
        ]
class PostgresLoggingConfig(LoggingConfig):
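    """Configuration for the Postgres-backed KV logging provider; connection
    details are read from environment variables."""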
provider: str = "postgres"
log_table: str = "logs"
log_info_table: str = "logs_pipeline_info"
def validate(self) -> None:
required_env_vars = [
"POSTGRES_DBNAME",
"POSTGRES_USER",
"POSTGRES_PASSWORD",
"POSTGRES_HOST",
"POSTGRES_PORT",
]
for var in required_env_vars:
if not os.getenv(var):
raise ValueError(f"Environment variable {var} is not set.")
@property
def supported_providers(self) -> list[str]:
return ["postgres"]
class PostgresKVLoggingProvider(KVLoggingProvider):
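    """KV logging provider backed by Postgres via an asyncpg connection pool."""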
def __init__(self, config: PostgresLoggingConfig):
self.log_table = config.log_table
self.log_info_table = config.log_info_table
self.config = config
self.pool = None
if not os.getenv("POSTGRES_DBNAME"):
raise ValueError(
"Please set the environment variable POSTGRES_DBNAME."
)
if not os.getenv("POSTGRES_USER"):
raise ValueError(
"Please set the environment variable POSTGRES_USER."
)
if not os.getenv("POSTGRES_PASSWORD"):
raise ValueError(
"Please set the environment variable POSTGRES_PASSWORD."
)
if not os.getenv("POSTGRES_HOST"):
raise ValueError(
"Please set the environment variable POSTGRES_HOST."
)
if not os.getenv("POSTGRES_PORT"):
raise ValueError(
"Please set the environment variable POSTGRES_PORT."
)
async def init(self):
self.pool = await asyncpg.create_pool(
database=os.getenv("POSTGRES_DBNAME"),
user=os.getenv("POSTGRES_USER"),
password=os.getenv("POSTGRES_PASSWORD"),
host=os.getenv("POSTGRES_HOST"),
port=os.getenv("POSTGRES_PORT"),
statement_cache_size=0, # Disable statement caching
)
async with self.pool.acquire() as conn:
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS "{self.log_table}" (
timestamp TIMESTAMPTZ,
log_id UUID,
key TEXT,
value TEXT
)
"""
)
await conn.execute(
f"""
CREATE TABLE IF NOT EXISTS "{self.log_info_table}" (
timestamp TIMESTAMPTZ,
log_id UUID UNIQUE,
log_type TEXT
)
"""
)
async def __aenter__(self):
if self.pool is None:
await self.init()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
await self.close()
async def close(self):
if self.pool:
await self.pool.close()
self.pool = None
async def log(
self,
log_id: uuid.UUID,
key: str,
value: str,
is_info_log=False,
):
collection = self.log_info_table if is_info_log else self.log_table
if is_info_log:
if "type" not in key:
raise ValueError(
"Info log key must contain the string `type`."
)
async with self.pool.acquire() as conn:
                await conn.execute(
f'INSERT INTO "{collection}" (timestamp, log_id, log_type) VALUES (NOW(), $1, $2)',
log_id,
value,
)
else:
async with self.pool.acquire() as conn:
await conn.execute(
f'INSERT INTO "{collection}" (timestamp, log_id, key, value) VALUES (NOW(), $1, $2, $3)',
log_id,
key,
value,
)
async def get_run_info(
self, limit: int = 10, log_type_filter: Optional[str] = None
) -> list[RunInfo]:
query = f"SELECT log_id, log_type FROM {self.log_info_table}"
conditions = []
params = []
if log_type_filter:
conditions.append("log_type = $1")
params.append(log_type_filter)
if conditions:
query += " WHERE " + " AND ".join(conditions)
query += " ORDER BY timestamp DESC LIMIT $2"
params.append(limit)
async with self.pool.acquire() as conn:
rows = await conn.fetch(query, *params)
return [
RunInfo(run_id=row["log_id"], log_type=row["log_type"])
for row in rows
]
async def get_logs(
self, run_ids: list[uuid.UUID], limit_per_run: int = 10
) -> list:
if not run_ids:
raise ValueError("No run ids provided.")
placeholders = ",".join([f"${i + 1}" for i in range(len(run_ids))])
query = f"""
SELECT * FROM (
SELECT *, ROW_NUMBER() OVER (PARTITION BY log_id ORDER BY timestamp DESC) as rn
FROM "{self.log_table}"
WHERE log_id::text IN ({placeholders})
) sub
WHERE sub.rn <= ${len(run_ids) + 1}
ORDER BY sub.timestamp DESC
"""
params = [str(run_id) for run_id in run_ids] + [limit_per_run]
async with self.pool.acquire() as conn:
rows = await conn.fetch(query, *params)
return [{key: row[key] for key in row.keys()} for row in rows]
class RedisLoggingConfig(LoggingConfig):
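    """Configuration for the Redis-backed KV logging provider."""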
provider: str = "redis"
log_table: str = "logs"
log_info_table: str = "logs_pipeline_info"
def validate(self) -> None:
required_env_vars = ["REDIS_CLUSTER_IP", "REDIS_CLUSTER_PORT"]
for var in required_env_vars:
if not os.getenv(var):
raise ValueError(f"Environment variable {var} is not set.")
@property
def supported_providers(self) -> list[str]:
return ["redis"]
class RedisKVLoggingProvider(KVLoggingProvider):
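    """KV logging provider backed by Redis; info logs live in a hash plus a
    sorted set for recency, regular logs in per-run lists."""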
def __init__(self, config: RedisLoggingConfig):
logger.info(
f"Initializing RedisKVLoggingProvider with config: {config}"
)
if not all(
[
os.getenv("REDIS_CLUSTER_IP"),
os.getenv("REDIS_CLUSTER_PORT"),
]
):
raise ValueError(
"Please set the environment variables REDIS_CLUSTER_IP and REDIS_CLUSTER_PORT to run `LoggingDatabaseConnection` with `redis`."
)
try:
from redis.asyncio import Redis
except ImportError:
            raise ImportError(
                "`redis` is not installed. Please install it using `pip install redis`."
            )
cluster_ip = os.getenv("REDIS_CLUSTER_IP")
port = os.getenv("REDIS_CLUSTER_PORT")
self.redis = Redis(host=cluster_ip, port=port, decode_responses=True)
self.log_key = config.log_table
self.log_info_key = config.log_info_table
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc, tb):
await self.close()
async def close(self):
await self.redis.close()
async def log(
self,
log_id: uuid.UUID,
key: str,
value: str,
is_info_log=False,
):
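        # Info logs are stored in a hash keyed by log_id plus a sorted set for
        # recency ordering; regular entries are pushed onto a per-run list.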
timestamp = datetime.now().timestamp()
log_entry = {
"timestamp": timestamp,
"log_id": str(log_id),
"key": key,
"value": value,
}
if is_info_log:
if "type" not in key:
raise ValueError("Metadata keys must contain the text 'type'")
log_entry["log_type"] = value
await self.redis.hset(
self.log_info_key, str(log_id), json.dumps(log_entry)
)
await self.redis.zadd(
f"{self.log_info_key}_sorted", {str(log_id): timestamp}
)
else:
await self.redis.lpush(
f"{self.log_key}:{str(log_id)}", json.dumps(log_entry)
)
async def get_run_info(
self, limit: int = 10, log_type_filter: Optional[str] = None
) -> list[RunInfo]:
run_info_list = []
start = 0
count_per_batch = 100 # Adjust batch size as needed
while len(run_info_list) < limit:
log_ids = await self.redis.zrevrange(
f"{self.log_info_key}_sorted",
start,
start + count_per_batch - 1,
)
if not log_ids:
break # No more log IDs to process
start += count_per_batch
for log_id in log_ids:
log_entry = json.loads(
await self.redis.hget(self.log_info_key, log_id)
)
                # Apply the optional type filter before collecting the entry.
                if (
                    not log_type_filter
                    or log_entry["log_type"] == log_type_filter
                ):
                    run_info_list.append(
                        RunInfo(
                            run_id=uuid.UUID(log_entry["log_id"]),
                            log_type=log_entry["log_type"],
                        )
                    )
if len(run_info_list) >= limit:
break
return run_info_list[:limit]
async def get_logs(
self, run_ids: list[uuid.UUID], limit_per_run: int = 10
) -> list:
logs = []
for run_id in run_ids:
raw_logs = await self.redis.lrange(
f"{self.log_key}:{str(run_id)}", 0, limit_per_run - 1
)
for raw_log in raw_logs:
json_log = json.loads(raw_log)
json_log["log_id"] = uuid.UUID(json_log["log_id"])
logs.append(json_log)
return logs
class KVLoggingSingleton:
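    """Process-wide logging facade that instantiates the configured provider
    on demand for each call."""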
_instance = None
_is_configured = False
SUPPORTED_PROVIDERS = {
"local": LocalKVLoggingProvider,
"postgres": PostgresKVLoggingProvider,
"redis": RedisKVLoggingProvider,
}
@classmethod
def get_instance(cls):
return cls.SUPPORTED_PROVIDERS[cls._config.provider](cls._config)
@classmethod
    def configure(
        cls, logging_config: Optional[LoggingConfig] = None
    ):
        # Avoid a shared mutable default argument; fall back to the default config.
        if logging_config is None:
            logging_config = LoggingConfig()
        if not cls._is_configured:
cls._config = logging_config
cls._is_configured = True
else:
raise Exception("KVLoggingSingleton is already configured.")
@classmethod
async def log(
cls,
log_id: uuid.UUID,
key: str,
value: str,
is_info_log=False,
):
try:
async with cls.get_instance() as provider:
await provider.log(log_id, key, value, is_info_log=is_info_log)
except Exception as e:
logger.error(f"Error logging data {(log_id, key, value)}: {e}")
@classmethod
async def get_run_info(
cls, limit: int = 10, log_type_filter: Optional[str] = None
) -> list[RunInfo]:
async with cls.get_instance() as provider:
return await provider.get_run_info(
limit, log_type_filter=log_type_filter
)
@classmethod
async def get_logs(
cls, run_ids: list[uuid.UUID], limit_per_run: int = 10
) -> list:
async with cls.get_instance() as provider:
return await provider.get_logs(run_ids, limit_per_run)
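# Example usage (illustrative sketch, run inside an async function; table and
# environment variable names are the defaults defined above):
#
#     from uuid import uuid4
#
#     KVLoggingSingleton.configure(LoggingConfig(provider="local"))
#     run_id = uuid4()
#     await KVLoggingSingleton.log(run_id, "pipeline_type", "search", is_info_log=True)
#     await KVLoggingSingleton.log(run_id, "query", "hello world")
#     run_info = await KVLoggingSingleton.get_run_info(limit=5)
#     logs = await KVLoggingSingleton.get_logs([run_id])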