|
@@ -0,0 +1,127 @@
|
|
|
|
|
+import redis
|
|
|
|
|
+import logging
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class RateLimiter:
|
|
|
|
|
+ """
|
|
|
|
|
+ A rate limiter that uses Redis as a backend to store request counts.
|
|
|
|
|
+
|
|
|
|
|
+ This class allows you to limit the number of requests made by a client
|
|
|
|
|
+ (identified by a key) within a specified time window.
|
|
|
|
|
+ """
|
|
|
|
|
+ def __init__(self, redis_host: str, redis_port: int,
|
|
|
|
|
+ time_window_sec: int, allowed_requests: int) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ Initialises the RateLimiter instance.
|
|
|
|
|
+
|
|
|
|
|
+ Parameters:
|
|
|
|
|
+ redis_host (str): The Redis server hostname.
|
|
|
|
|
+ redis_port (int): The Redis server port.
|
|
|
|
|
+ time_window_sec (int): The time window (in seconds) in which
|
|
|
|
|
+ requests are counted.
|
|
|
|
|
+ allowed_requests (int): The maximum number of requests allowed
|
|
|
|
|
+ in the time window.
|
|
|
|
|
+ """
|
|
|
|
|
+ self.__redis_client = redis.Redis(
|
|
|
|
|
+ host=redis_host,
|
|
|
|
|
+ port=redis_port,
|
|
|
|
|
+ decode_responses=True
|
|
|
|
|
+ )
|
|
|
|
|
+ self.logger = logging.getLogger(__name__)
|
|
|
|
|
+ self.logger.info(
|
|
|
|
|
+ f"Redis connection established with {redis_host}:{redis_port}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # Unique key prefix for this rate limiter instance
|
|
|
|
|
+ self.__key_prefix = f"rl-{id(self)}-"
|
|
|
|
|
+ self.__time_window_sec = time_window_sec
|
|
|
|
|
+ self.__allowed_requests = allowed_requests
|
|
|
|
|
+
|
|
|
|
|
+ self.logger.info(
|
|
|
|
|
+ "RateLimiter initialized with parameters: "
|
|
|
|
|
+ f"Key prefix: {self.__key_prefix}, "
|
|
|
|
|
+ f"Time window: {self.__time_window_sec}s, "
|
|
|
|
|
+ f"Allowed requests per window: {self.__allowed_requests}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def __del__(self) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ Clean up and close the Redis connection when the RateLimiter
|
|
|
|
|
+ is deleted.
|
|
|
|
|
+ """
|
|
|
|
|
+ if self.__redis_client:
|
|
|
|
|
+ self.__redis_client.close()
|
|
|
|
|
+ self.logger.debug(
|
|
|
|
|
+ f"Redis connection closed for RateLimiter with id {id(self)}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ def __get_prefixed_key(self, key: str) -> str:
|
|
|
|
|
+ """
|
|
|
|
|
+ Generates a unique key for Redis by adding a prefix.
|
|
|
|
|
+
|
|
|
|
|
+ This helps avoid key collision in Redis with other data stored there.
|
|
|
|
|
+
|
|
|
|
|
+ Parameters:
|
|
|
|
|
+ key (str): The key (e.g., client identifier) to be used for rate
|
|
|
|
|
+ limiting.
|
|
|
|
|
+
|
|
|
|
|
+ Returns:
|
|
|
|
|
+ str: The Redis key with the instance-specific prefix.
|
|
|
|
|
+ """
|
|
|
|
|
+ return self.__key_prefix + key
|
|
|
|
|
+
|
|
|
|
|
+ def count(self, key: str) -> None:
|
|
|
|
|
+ """
|
|
|
|
|
+ Increment the request count for a specific key (e.g., an IP address)
|
|
|
|
|
+ within the current time window.
|
|
|
|
|
+
|
|
|
|
|
+ Parameters:
|
|
|
|
|
+ key (str): The key for which the request count is being updated.
|
|
|
|
|
+ For example, an IP address if rate limiting based on IPs.
|
|
|
|
|
+
|
|
|
|
|
+ Raises:
|
|
|
|
|
+ RateLimitExceededException: If the number of requests exceeds the
|
|
|
|
|
+ allowed limit for the current time window.
|
|
|
|
|
+ """
|
|
|
|
|
+ self.logger.debug(f"Counting a request for key: {key}")
|
|
|
|
|
+ pfx_key = self.__get_prefixed_key(key)
|
|
|
|
|
+
|
|
|
|
|
+ # Check if the key already exists in Redis
|
|
|
|
|
+ if self.__redis_client.exists(pfx_key):
|
|
|
|
|
+ current_count = int(self.__redis_client.get(pfx_key))
|
|
|
|
|
+ self.logger.debug(
|
|
|
|
|
+ f"Current request count for '{pfx_key}': {current_count}"
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+ # If request count exceeds the allowed limit, raise exception
|
|
|
|
|
+ if current_count >= self.__allowed_requests:
|
|
|
|
|
+ self.logger.warning(f"Rate limit exceeded for key '{pfx_key}'")
|
|
|
|
|
+ raise RateLimitExceededException
|
|
|
|
|
+
|
|
|
|
|
+ # Increment request count and keep TTL (time-to-live) unchanged
|
|
|
|
|
+ self.__redis_client.set(
|
|
|
|
|
+ name=pfx_key,
|
|
|
|
|
+ value=(current_count + 1),
|
|
|
|
|
+ keepttl=True
|
|
|
|
|
+ )
|
|
|
|
|
+ else:
|
|
|
|
|
+ # Key doesn't exist yet, initialise count with TTL for time window
|
|
|
|
|
+ self.logger.debug(
|
|
|
|
|
+ f"No previous requests for key '{pfx_key}' in current window"
|
|
|
|
|
+ ", initialising count to 1"
|
|
|
|
|
+ )
|
|
|
|
|
+ self.__redis_client.set(
|
|
|
|
|
+ name=pfx_key,
|
|
|
|
|
+ value=1,
|
|
|
|
|
+ ex=self.__time_window_sec
|
|
|
|
|
+ )
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class RateLimiterException(Exception):
|
|
|
|
|
+ pass
|
|
|
|
|
+
|
|
|
|
|
+
|
|
|
|
|
+class RateLimitExceededException(RateLimiterException):
|
|
|
|
|
+ def __init__(self, *args):
|
|
|
|
|
+ message = "Too many requests. Try after some time."
|
|
|
|
|
+ super().__init__(message)
|