aboutsummaryrefslogtreecommitdiff
path: root/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
diff options
context:
space:
mode:
authorS. Solomon Darnell2025-03-28 21:52:21 -0500
committerS. Solomon Darnell2025-03-28 21:52:21 -0500
commit4a52a71956a8d46fcb7294ac71734504bb09bcc2 (patch)
treeee3dc5af3b6313e921cd920906356f5d4febc4ed /.venv/lib/python3.12/site-packages/core/providers/database/limits.py
parentcc961e04ba734dd72309fb548a2f97d67d578813 (diff)
downloadgn-ai-master.tar.gz
two version of R2R are hereHEADmaster
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/limits.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/database/limits.py434
1 files changed, 434 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/limits.py b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
new file mode 100644
index 00000000..1029ec50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
@@ -0,0 +1,434 @@
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from core.base import Handler
+from shared.abstractions import User
+
+from ...base.providers.database import DatabaseConfig, LimitSettings
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+
+class PostgresLimitsHandler(Handler):
+ TABLE_NAME = "request_log"
+
+ def __init__(
+ self,
+ project_name: str,
+ connection_manager: PostgresConnectionManager,
+ config: DatabaseConfig,
+ ):
+ """
+ :param config: The global DatabaseConfig with default rate limits.
+ """
+ super().__init__(project_name, connection_manager)
+ self.config = config
+
+ logger.debug(
+ f"Initialized PostgresLimitsHandler with project: {project_name}"
+ )
+
+ async def create_tables(self):
+ query = f"""
+ CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
+ time TIMESTAMPTZ NOT NULL,
+ user_id UUID NOT NULL,
+ route TEXT NOT NULL
+ );
+ """
+ logger.debug("Creating request_log table if not exists")
+ await self.connection_manager.execute_query(query)
+
+ async def _count_requests(
+ self,
+ user_id: UUID,
+ route: Optional[str],
+ since: datetime,
+ ) -> int:
+ """Count how many requests a user (optionally for a specific route) has
+ made since the given datetime."""
+ if route:
+ query = f"""
+ SELECT COUNT(*)::int
+ FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ WHERE user_id = $1
+ AND route = $2
+ AND time >= $3
+ """
+ params = [user_id, route, since]
+ logger.debug(
+ f"Counting requests for user={user_id}, route={route}"
+ )
+ else:
+ query = f"""
+ SELECT COUNT(*)::int
+ FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ WHERE user_id = $1
+ AND time >= $2
+ """
+ params = [user_id, since]
+ logger.debug(f"Counting all requests for user={user_id}")
+
+ result = await self.connection_manager.fetchrow_query(query, params)
+ return result["count"] if result else 0
+
+ async def _count_monthly_requests(
+ self,
+ user_id: UUID,
+ route: Optional[str] = None, # <--- ADDED THIS
+ ) -> int:
+ """Count the number of requests so far this month for a given user.
+
+ If route is provided, count only for that route. Otherwise, count
+ globally.
+ """
+ now = datetime.now(timezone.utc)
+ start_of_month = now.replace(
+ day=1, hour=0, minute=0, second=0, microsecond=0
+ )
+ return await self._count_requests(
+ user_id, route=route, since=start_of_month
+ )
+
+ def determine_effective_limits(
+ self, user: User, route: str
+ ) -> LimitSettings:
+ """
+ Determine the final effective limits for a user+route combination,
+ respecting:
+ 1) Global defaults
+ 2) Route-specific overrides
+ 3) User-level overrides
+ """
+ # ------------------------
+ # 1) Start with global/base
+ # ------------------------
+ base_limits = self.config.limits
+
+ # We’ll make a copy so we don’t mutate self.config.limits directly
+ effective = LimitSettings(
+ global_per_min=base_limits.global_per_min,
+ route_per_min=base_limits.route_per_min,
+ monthly_limit=base_limits.monthly_limit,
+ )
+
+ # ------------------------
+ # 2) Route-level overrides
+ # ------------------------
+ route_config = self.config.route_limits.get(route)
+ if route_config:
+ if route_config.global_per_min is not None:
+ effective.global_per_min = route_config.global_per_min
+ if route_config.route_per_min is not None:
+ effective.route_per_min = route_config.route_per_min
+ if route_config.monthly_limit is not None:
+ effective.monthly_limit = route_config.monthly_limit
+
+ # ------------------------
+ # 3) User-level overrides
+ # ------------------------
+ # The user object might have a dictionary of overrides
+ # which can include route_overrides, global_per_min, monthly_limit, etc.
+ user_overrides = user.limits_overrides or {}
+
+ # (a) "global" user overrides
+ if user_overrides.get("global_per_min") is not None:
+ effective.global_per_min = user_overrides["global_per_min"]
+ if user_overrides.get("monthly_limit") is not None:
+ effective.monthly_limit = user_overrides["monthly_limit"]
+
+ # (b) route-level user overrides
+ route_overrides = user_overrides.get("route_overrides", {})
+ specific_config = route_overrides.get(route, {})
+ if specific_config.get("global_per_min") is not None:
+ effective.global_per_min = specific_config["global_per_min"]
+ if specific_config.get("route_per_min") is not None:
+ effective.route_per_min = specific_config["route_per_min"]
+ if specific_config.get("monthly_limit") is not None:
+ effective.monthly_limit = specific_config["monthly_limit"]
+
+ return effective
+
+ async def check_limits(self, user: User, route: str):
+ """Perform rate limit checks for a user on a specific route.
+
+ :param user: The fully-fetched User object with .limits_overrides, etc.
+ :param route: The route/path being accessed.
+ :raises ValueError: if any limit is exceeded.
+ """
+ user_id = user.id
+ now = datetime.now(timezone.utc)
+ one_min_ago = now - timedelta(minutes=1)
+
+ # 1) Compute the final (effective) limits for this user & route
+ limits = self.determine_effective_limits(user, route)
+
+ # 2) Check each of them in turn, if they exist
+ # ------------------------------------------------------------
+ # Global per-minute limit
+ # ------------------------------------------------------------
+ if limits.global_per_min is not None:
+ user_req_count = await self._count_requests(
+ user_id, None, one_min_ago
+ )
+ if user_req_count > limits.global_per_min:
+ logger.warning(
+ f"Global per-minute limit exceeded for "
+ f"user_id={user_id}, route={route}"
+ )
+ raise ValueError("Global per-minute rate limit exceeded")
+
+ # ------------------------------------------------------------
+ # Route-specific per-minute limit
+ # ------------------------------------------------------------
+ if limits.route_per_min is not None:
+ route_req_count = await self._count_requests(
+ user_id, route, one_min_ago
+ )
+ if route_req_count > limits.route_per_min:
+ logger.warning(
+ f"Per-route per-minute limit exceeded for "
+ f"user_id={user_id}, route={route}"
+ )
+ raise ValueError("Per-route per-minute rate limit exceeded")
+
+ # ------------------------------------------------------------
+ # Monthly limit
+ # ------------------------------------------------------------
+ if limits.monthly_limit is not None:
+ # If you truly want a per-route monthly limit, we pass 'route'.
+ # If you want a global monthly limit, pass 'None'.
+ monthly_count = await self._count_monthly_requests(user_id, route)
+ if monthly_count > limits.monthly_limit:
+ logger.warning(
+ f"Monthly limit exceeded for user_id={user_id}, "
+ f"route={route}"
+ )
+ raise ValueError("Monthly rate limit exceeded")
+
+ async def log_request(self, user_id: UUID, route: str):
+ """Log a successful request to the request_log table."""
+ query = f"""
+ INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+ (time, user_id, route)
+ VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
+ """
+ await self.connection_manager.execute_query(query, [user_id, route])
+
+
+# import logging
+# from datetime import datetime, timedelta, timezone
+# from typing import Optional
+# from uuid import UUID
+
+# from core.base import Handler
+# from shared.abstractions import User
+
+# from ..base.providers.database import DatabaseConfig, LimitSettings
+# from .base import PostgresConnectionManager
+
+# logger = logging.getLogger(__name__)
+
+# class PostgresLimitsHandler(Handler):
+# TABLE_NAME = "request_log"
+
+# def __init__(
+# self,
+# project_name: str,
+# connection_manager: PostgresConnectionManager,
+# config: DatabaseConfig,
+# ):
+# """
+# :param config: The global DatabaseConfig with default rate limits.
+# """
+# super().__init__(project_name, connection_manager)
+# self.config = config
+
+# logger.debug(
+# f"Initialized PostgresLimitsHandler with project: {project_name}"
+# )
+
+# async def create_tables(self):
+# query = f"""
+# CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
+# time TIMESTAMPTZ NOT NULL,
+# user_id UUID NOT NULL,
+# route TEXT NOT NULL
+# );
+# """
+# logger.debug("Creating request_log table if not exists")
+# await self.connection_manager.execute_query(query)
+
+# async def _count_requests(
+# self,
+# user_id: UUID,
+# route: Optional[str],
+# since: datetime,
+# ) -> int:
+# """
+# Count how many requests a user (optionally for a specific route)
+# has made since the given datetime.
+# """
+# if route:
+# query = f"""
+# SELECT COUNT(*)::int
+# FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+# WHERE user_id = $1
+# AND route = $2
+# AND time >= $3
+# """
+# params = [user_id, route, since]
+# logger.debug(f"Counting requests for user={user_id}, route={route}")
+# else:
+# query = f"""
+# SELECT COUNT(*)::int
+# FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+# WHERE user_id = $1
+# AND time >= $2
+# """
+# params = [user_id, since]
+# logger.debug(f"Counting all requests for user={user_id}")
+
+# result = await self.connection_manager.fetchrow_query(query, params)
+# return result["count"] if result else 0
+
+# async def _count_monthly_requests(self, user_id: UUID) -> int:
+# """
+# Count the number of requests so far this month for a given user.
+# """
+# now = datetime.now(timezone.utc)
+# start_of_month = now.replace(
+# day=1, hour=0, minute=0, second=0, microsecond=0
+# )
+# return await self._count_requests(
+# user_id, route=None, since=start_of_month
+# )
+
+# def determine_effective_limits(
+# self, user: User, route: str
+# ) -> LimitSettings:
+# """
+# Determine the final effective limits for a user+route combination,
+# respecting:
+# 1) Global defaults
+# 2) Route-specific overrides
+# 3) User-level overrides
+# """
+# # ------------------------
+# # 1) Start with global/base
+# # ------------------------
+# base_limits = self.config.limits
+
+# # We’ll make a copy so we don’t mutate self.config.limits directly
+# effective = LimitSettings(
+# global_per_min=base_limits.global_per_min,
+# route_per_min=base_limits.route_per_min,
+# monthly_limit=base_limits.monthly_limit,
+# )
+
+# # ------------------------
+# # 2) Route-level overrides
+# # ------------------------
+# route_config = self.config.route_limits.get(route)
+# if route_config:
+# if route_config.global_per_min is not None:
+# effective.global_per_min = route_config.global_per_min
+# if route_config.route_per_min is not None:
+# effective.route_per_min = route_config.route_per_min
+# if route_config.monthly_limit is not None:
+# effective.monthly_limit = route_config.monthly_limit
+
+# # ------------------------
+# # 3) User-level overrides
+# # ------------------------
+# # The user object might have a dictionary of overrides
+# # which can include route_overrides, global_per_min, monthly_limit, etc.
+# user_overrides = user.limits_overrides or {}
+
+# # (a) "global" user overrides
+# if user_overrides.get("global_per_min") is not None:
+# effective.global_per_min = user_overrides["global_per_min"]
+# if user_overrides.get("monthly_limit") is not None:
+# effective.monthly_limit = user_overrides["monthly_limit"]
+
+# # (b) route-level user overrides
+# route_overrides = user_overrides.get("route_overrides", {})
+# specific_config = route_overrides.get(route, {})
+# if specific_config.get("global_per_min") is not None:
+# effective.global_per_min = specific_config["global_per_min"]
+# if specific_config.get("route_per_min") is not None:
+# effective.route_per_min = specific_config["route_per_min"]
+# if specific_config.get("monthly_limit") is not None:
+# effective.monthly_limit = specific_config["monthly_limit"]
+
+# return effective
+
+# async def check_limits(self, user: User, route: str):
+# """
+# Perform rate limit checks for a user on a specific route.
+
+# :param user: The fully-fetched User object with .limits_overrides, etc.
+# :param route: The route/path being accessed.
+# :raises ValueError: if any limit is exceeded.
+# """
+# user_id = user.id
+# now = datetime.now(timezone.utc)
+# one_min_ago = now - timedelta(minutes=1)
+
+# # 1) Compute the final (effective) limits for this user & route
+# limits = self.determine_effective_limits(user, route)
+
+# # 2) Check each of them in turn, if they exist
+# # ------------------------------------------------------------
+# # Global per-minute limit
+# # ------------------------------------------------------------
+# if limits.global_per_min is not None:
+# user_req_count = await self._count_requests(
+# user_id, None, one_min_ago
+# )
+# if user_req_count > limits.global_per_min:
+# logger.warning(
+# f"Global per-minute limit exceeded for "
+# f"user_id={user_id}, route={route}"
+# )
+# raise ValueError("Global per-minute rate limit exceeded")
+
+# # ------------------------------------------------------------
+# # Route-specific per-minute limit
+# # ------------------------------------------------------------
+# if limits.route_per_min is not None:
+# route_req_count = await self._count_requests(
+# user_id, route, one_min_ago
+# )
+# if route_req_count > limits.route_per_min:
+# logger.warning(
+# f"Per-route per-minute limit exceeded for "
+# f"user_id={user_id}, route={route}"
+# )
+# raise ValueError("Per-route per-minute rate limit exceeded")
+
+# # ------------------------------------------------------------
+# # Monthly limit
+# # ------------------------------------------------------------
+# if limits.monthly_limit is not None:
+# monthly_count = await self._count_monthly_requests(user_id)
+# if monthly_count > limits.monthly_limit:
+# logger.warning(
+# f"Monthly limit exceeded for user_id={user_id}, "
+# f"route={route}"
+# )
+# raise ValueError("Monthly rate limit exceeded")
+
+# async def log_request(self, user_id: UUID, route: str):
+# """
+# Log a successful request to the request_log table.
+# """
+# query = f"""
+# INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+# (time, user_id, route)
+# VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
+# """
+# await self.connection_manager.execute_query(query, [user_id, route])