about summary refs log tree commit diff
path: root/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
diff options
context:
space:
mode:
Diffstat (limited to '.venv/lib/python3.12/site-packages/core/providers/database/limits.py')
-rw-r--r--.venv/lib/python3.12/site-packages/core/providers/database/limits.py434
1 files changed, 434 insertions, 0 deletions
diff --git a/.venv/lib/python3.12/site-packages/core/providers/database/limits.py b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
new file mode 100644
index 00000000..1029ec50
--- /dev/null
+++ b/.venv/lib/python3.12/site-packages/core/providers/database/limits.py
@@ -0,0 +1,434 @@
+import logging
+from datetime import datetime, timedelta, timezone
+from typing import Optional
+from uuid import UUID
+
+from core.base import Handler
+from shared.abstractions import User
+
+from ...base.providers.database import DatabaseConfig, LimitSettings
+from .base import PostgresConnectionManager
+
+logger = logging.getLogger(__name__)
+
+
+class PostgresLimitsHandler(Handler):
+    TABLE_NAME = "request_log"
+
+    def __init__(
+        self,
+        project_name: str,
+        connection_manager: PostgresConnectionManager,
+        config: DatabaseConfig,
+    ):
+        """
+        :param config: The global DatabaseConfig with default rate limits.
+        """
+        super().__init__(project_name, connection_manager)
+        self.config = config
+
+        logger.debug(
+            f"Initialized PostgresLimitsHandler with project: {project_name}"
+        )
+
+    async def create_tables(self):
+        query = f"""
+        CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
+            time TIMESTAMPTZ NOT NULL,
+            user_id UUID NOT NULL,
+            route TEXT NOT NULL
+        );
+        """
+        logger.debug("Creating request_log table if not exists")
+        await self.connection_manager.execute_query(query)
+
+    async def _count_requests(
+        self,
+        user_id: UUID,
+        route: Optional[str],
+        since: datetime,
+    ) -> int:
+        """Count how many requests a user (optionally for a specific route) has
+        made since the given datetime."""
+        if route:
+            query = f"""
+            SELECT COUNT(*)::int
+            FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+            WHERE user_id = $1
+              AND route = $2
+              AND time >= $3
+            """
+            params = [user_id, route, since]
+            logger.debug(
+                f"Counting requests for user={user_id}, route={route}"
+            )
+        else:
+            query = f"""
+            SELECT COUNT(*)::int
+            FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+            WHERE user_id = $1
+              AND time >= $2
+            """
+            params = [user_id, since]
+            logger.debug(f"Counting all requests for user={user_id}")
+
+        result = await self.connection_manager.fetchrow_query(query, params)
+        return result["count"] if result else 0
+
+    async def _count_monthly_requests(
+        self,
+        user_id: UUID,
+        route: Optional[str] = None,  # <--- ADDED THIS
+    ) -> int:
+        """Count the number of requests so far this month for a given user.
+
+        If route is provided, count only for that route. Otherwise, count
+        globally.
+        """
+        now = datetime.now(timezone.utc)
+        start_of_month = now.replace(
+            day=1, hour=0, minute=0, second=0, microsecond=0
+        )
+        return await self._count_requests(
+            user_id, route=route, since=start_of_month
+        )
+
+    def determine_effective_limits(
+        self, user: User, route: str
+    ) -> LimitSettings:
+        """
+        Determine the final effective limits for a user+route combination,
+        respecting:
+          1) Global defaults
+          2) Route-specific overrides
+          3) User-level overrides
+        """
+        # ------------------------
+        # 1) Start with global/base
+        # ------------------------
+        base_limits = self.config.limits
+
+        # We’ll make a copy so we don’t mutate self.config.limits directly
+        effective = LimitSettings(
+            global_per_min=base_limits.global_per_min,
+            route_per_min=base_limits.route_per_min,
+            monthly_limit=base_limits.monthly_limit,
+        )
+
+        # ------------------------
+        # 2) Route-level overrides
+        # ------------------------
+        route_config = self.config.route_limits.get(route)
+        if route_config:
+            if route_config.global_per_min is not None:
+                effective.global_per_min = route_config.global_per_min
+            if route_config.route_per_min is not None:
+                effective.route_per_min = route_config.route_per_min
+            if route_config.monthly_limit is not None:
+                effective.monthly_limit = route_config.monthly_limit
+
+        # ------------------------
+        # 3) User-level overrides
+        # ------------------------
+        # The user object might have a dictionary of overrides
+        # which can include route_overrides, global_per_min, monthly_limit, etc.
+        user_overrides = user.limits_overrides or {}
+
+        # (a) "global" user overrides
+        if user_overrides.get("global_per_min") is not None:
+            effective.global_per_min = user_overrides["global_per_min"]
+        if user_overrides.get("monthly_limit") is not None:
+            effective.monthly_limit = user_overrides["monthly_limit"]
+
+        # (b) route-level user overrides
+        route_overrides = user_overrides.get("route_overrides", {})
+        specific_config = route_overrides.get(route, {})
+        if specific_config.get("global_per_min") is not None:
+            effective.global_per_min = specific_config["global_per_min"]
+        if specific_config.get("route_per_min") is not None:
+            effective.route_per_min = specific_config["route_per_min"]
+        if specific_config.get("monthly_limit") is not None:
+            effective.monthly_limit = specific_config["monthly_limit"]
+
+        return effective
+
+    async def check_limits(self, user: User, route: str):
+        """Perform rate limit checks for a user on a specific route.
+
+        :param user: The fully-fetched User object with .limits_overrides, etc.
+        :param route: The route/path being accessed.
+        :raises ValueError: if any limit is exceeded.
+        """
+        user_id = user.id
+        now = datetime.now(timezone.utc)
+        one_min_ago = now - timedelta(minutes=1)
+
+        # 1) Compute the final (effective) limits for this user & route
+        limits = self.determine_effective_limits(user, route)
+
+        # 2) Check each of them in turn, if they exist
+        # ------------------------------------------------------------
+        # Global per-minute limit
+        # ------------------------------------------------------------
+        if limits.global_per_min is not None:
+            user_req_count = await self._count_requests(
+                user_id, None, one_min_ago
+            )
+            if user_req_count > limits.global_per_min:
+                logger.warning(
+                    f"Global per-minute limit exceeded for "
+                    f"user_id={user_id}, route={route}"
+                )
+                raise ValueError("Global per-minute rate limit exceeded")
+
+        # ------------------------------------------------------------
+        # Route-specific per-minute limit
+        # ------------------------------------------------------------
+        if limits.route_per_min is not None:
+            route_req_count = await self._count_requests(
+                user_id, route, one_min_ago
+            )
+            if route_req_count > limits.route_per_min:
+                logger.warning(
+                    f"Per-route per-minute limit exceeded for "
+                    f"user_id={user_id}, route={route}"
+                )
+                raise ValueError("Per-route per-minute rate limit exceeded")
+
+        # ------------------------------------------------------------
+        # Monthly limit
+        # ------------------------------------------------------------
+        if limits.monthly_limit is not None:
+            # If you truly want a per-route monthly limit, we pass 'route'.
+            # If you want a global monthly limit, pass 'None'.
+            monthly_count = await self._count_monthly_requests(user_id, route)
+            if monthly_count > limits.monthly_limit:
+                logger.warning(
+                    f"Monthly limit exceeded for user_id={user_id}, "
+                    f"route={route}"
+                )
+                raise ValueError("Monthly rate limit exceeded")
+
+    async def log_request(self, user_id: UUID, route: str):
+        """Log a successful request to the request_log table."""
+        query = f"""
+        INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+        (time, user_id, route)
+        VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
+        """
+        await self.connection_manager.execute_query(query, [user_id, route])
+
+
+# import logging
+# from datetime import datetime, timedelta, timezone
+# from typing import Optional
+# from uuid import UUID
+
+# from core.base import Handler
+# from shared.abstractions import User
+
+# from ..base.providers.database import DatabaseConfig, LimitSettings
+# from .base import PostgresConnectionManager
+
+# logger = logging.getLogger(__name__)
+
+# class PostgresLimitsHandler(Handler):
+#     TABLE_NAME = "request_log"
+
+#     def __init__(
+#         self,
+#         project_name: str,
+#         connection_manager: PostgresConnectionManager,
+#         config: DatabaseConfig,
+#     ):
+#         """
+#         :param config: The global DatabaseConfig with default rate limits.
+#         """
+#         super().__init__(project_name, connection_manager)
+#         self.config = config
+
+#         logger.debug(
+#             f"Initialized PostgresLimitsHandler with project: {project_name}"
+#         )
+
+#     async def create_tables(self):
+#         query = f"""
+#         CREATE TABLE IF NOT EXISTS {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)} (
+#             time TIMESTAMPTZ NOT NULL,
+#             user_id UUID NOT NULL,
+#             route TEXT NOT NULL
+#         );
+#         """
+#         logger.debug("Creating request_log table if not exists")
+#         await self.connection_manager.execute_query(query)
+
+#     async def _count_requests(
+#         self,
+#         user_id: UUID,
+#         route: Optional[str],
+#         since: datetime,
+#     ) -> int:
+#         """
+#         Count how many requests a user (optionally for a specific route)
+#         has made since the given datetime.
+#         """
+#         if route:
+#             query = f"""
+#             SELECT COUNT(*)::int
+#             FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+#             WHERE user_id = $1
+#               AND route = $2
+#               AND time >= $3
+#             """
+#             params = [user_id, route, since]
+#             logger.debug(f"Counting requests for user={user_id}, route={route}")
+#         else:
+#             query = f"""
+#             SELECT COUNT(*)::int
+#             FROM {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+#             WHERE user_id = $1
+#               AND time >= $2
+#             """
+#             params = [user_id, since]
+#             logger.debug(f"Counting all requests for user={user_id}")
+
+#         result = await self.connection_manager.fetchrow_query(query, params)
+#         return result["count"] if result else 0
+
+#     async def _count_monthly_requests(self, user_id: UUID) -> int:
+#         """
+#         Count the number of requests so far this month for a given user.
+#         """
+#         now = datetime.now(timezone.utc)
+#         start_of_month = now.replace(
+#             day=1, hour=0, minute=0, second=0, microsecond=0
+#         )
+#         return await self._count_requests(
+#             user_id, route=None, since=start_of_month
+#         )
+
+#     def determine_effective_limits(
+#         self, user: User, route: str
+#     ) -> LimitSettings:
+#         """
+#         Determine the final effective limits for a user+route combination,
+#         respecting:
+#           1) Global defaults
+#           2) Route-specific overrides
+#           3) User-level overrides
+#         """
+#         # ------------------------
+#         # 1) Start with global/base
+#         # ------------------------
+#         base_limits = self.config.limits
+
+#         # We’ll make a copy so we don’t mutate self.config.limits directly
+#         effective = LimitSettings(
+#             global_per_min=base_limits.global_per_min,
+#             route_per_min=base_limits.route_per_min,
+#             monthly_limit=base_limits.monthly_limit,
+#         )
+
+#         # ------------------------
+#         # 2) Route-level overrides
+#         # ------------------------
+#         route_config = self.config.route_limits.get(route)
+#         if route_config:
+#             if route_config.global_per_min is not None:
+#                 effective.global_per_min = route_config.global_per_min
+#             if route_config.route_per_min is not None:
+#                 effective.route_per_min = route_config.route_per_min
+#             if route_config.monthly_limit is not None:
+#                 effective.monthly_limit = route_config.monthly_limit
+
+#         # ------------------------
+#         # 3) User-level overrides
+#         # ------------------------
+#         # The user object might have a dictionary of overrides
+#         # which can include route_overrides, global_per_min, monthly_limit, etc.
+#         user_overrides = user.limits_overrides or {}
+
+#         # (a) "global" user overrides
+#         if user_overrides.get("global_per_min") is not None:
+#             effective.global_per_min = user_overrides["global_per_min"]
+#         if user_overrides.get("monthly_limit") is not None:
+#             effective.monthly_limit = user_overrides["monthly_limit"]
+
+#         # (b) route-level user overrides
+#         route_overrides = user_overrides.get("route_overrides", {})
+#         specific_config = route_overrides.get(route, {})
+#         if specific_config.get("global_per_min") is not None:
+#             effective.global_per_min = specific_config["global_per_min"]
+#         if specific_config.get("route_per_min") is not None:
+#             effective.route_per_min = specific_config["route_per_min"]
+#         if specific_config.get("monthly_limit") is not None:
+#             effective.monthly_limit = specific_config["monthly_limit"]
+
+#         return effective
+
+#     async def check_limits(self, user: User, route: str):
+#         """
+#         Perform rate limit checks for a user on a specific route.
+
+#         :param user: The fully-fetched User object with .limits_overrides, etc.
+#         :param route: The route/path being accessed.
+#         :raises ValueError: if any limit is exceeded.
+#         """
+#         user_id = user.id
+#         now = datetime.now(timezone.utc)
+#         one_min_ago = now - timedelta(minutes=1)
+
+#         # 1) Compute the final (effective) limits for this user & route
+#         limits = self.determine_effective_limits(user, route)
+
+#         # 2) Check each of them in turn, if they exist
+#         # ------------------------------------------------------------
+#         # Global per-minute limit
+#         # ------------------------------------------------------------
+#         if limits.global_per_min is not None:
+#             user_req_count = await self._count_requests(
+#                 user_id, None, one_min_ago
+#             )
+#             if user_req_count > limits.global_per_min:
+#                 logger.warning(
+#                     f"Global per-minute limit exceeded for "
+#                     f"user_id={user_id}, route={route}"
+#                 )
+#                 raise ValueError("Global per-minute rate limit exceeded")
+
+#         # ------------------------------------------------------------
+#         # Route-specific per-minute limit
+#         # ------------------------------------------------------------
+#         if limits.route_per_min is not None:
+#             route_req_count = await self._count_requests(
+#                 user_id, route, one_min_ago
+#             )
+#             if route_req_count > limits.route_per_min:
+#                 logger.warning(
+#                     f"Per-route per-minute limit exceeded for "
+#                     f"user_id={user_id}, route={route}"
+#                 )
+#                 raise ValueError("Per-route per-minute rate limit exceeded")
+
+#         # ------------------------------------------------------------
+#         # Monthly limit
+#         # ------------------------------------------------------------
+#         if limits.monthly_limit is not None:
+#             monthly_count = await self._count_monthly_requests(user_id)
+#             if monthly_count > limits.monthly_limit:
+#                 logger.warning(
+#                     f"Monthly limit exceeded for user_id={user_id}, "
+#                     f"route={route}"
+#                 )
+#                 raise ValueError("Monthly rate limit exceeded")
+
+#     async def log_request(self, user_id: UUID, route: str):
+#         """
+#         Log a successful request to the request_log table.
+#         """
+#         query = f"""
+#         INSERT INTO {self._get_table_name(PostgresLimitsHandler.TABLE_NAME)}
+#         (time, user_id, route)
+#         VALUES (CURRENT_TIMESTAMP AT TIME ZONE 'UTC', $1, $2)
+#         """
+#         await self.connection_manager.execute_query(query, [user_id, route])