ratelimiter.py 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. import redis
  2. import logging
  3. class RateLimiter:
  4. """
  5. A rate limiter that uses Redis as a backend to store request counts.
  6. This class allows you to limit the number of requests made by a client
  7. (identified by a key) within a specified time window.
  8. """
  9. def __init__(self, redis_host: str, redis_port: int,
  10. time_window_sec: int, allowed_requests: int) -> None:
  11. """
  12. Initialises the RateLimiter instance.
  13. Parameters:
  14. redis_host (str): The Redis server hostname.
  15. redis_port (int): The Redis server port.
  16. time_window_sec (int): The time window (in seconds) in which
  17. requests are counted.
  18. allowed_requests (int): The maximum number of requests allowed
  19. in the time window.
  20. """
  21. self.__redis_client = redis.Redis(
  22. host=redis_host,
  23. port=redis_port,
  24. decode_responses=True
  25. )
  26. self.logger = logging.getLogger(__name__)
  27. self.logger.info(
  28. f"Redis connection established with {redis_host}:{redis_port}"
  29. )
  30. # Unique key prefix for this rate limiter instance
  31. self.__key_prefix = f"rl-{id(self)}-"
  32. self.__time_window_sec = time_window_sec
  33. self.__allowed_requests = allowed_requests
  34. self.logger.info(
  35. "RateLimiter initialized with parameters: "
  36. f"Key prefix: {self.__key_prefix}, "
  37. f"Time window: {self.__time_window_sec}s, "
  38. f"Allowed requests per window: {self.__allowed_requests}"
  39. )
  40. def __del__(self) -> None:
  41. """
  42. Clean up and close the Redis connection when the RateLimiter
  43. is deleted.
  44. """
  45. if self.__redis_client:
  46. self.__redis_client.close()
  47. self.logger.debug(
  48. f"Redis connection closed for RateLimiter with id {id(self)}"
  49. )
  50. def __get_prefixed_key(self, key: str) -> str:
  51. """
  52. Generates a unique key for Redis by adding a prefix.
  53. This helps avoid key collision in Redis with other data stored there.
  54. Parameters:
  55. key (str): The key (e.g., client identifier) to be used for rate
  56. limiting.
  57. Returns:
  58. str: The Redis key with the instance-specific prefix.
  59. """
  60. return self.__key_prefix + key
  61. def count(self, key: str) -> None:
  62. """
  63. Increment the request count for a specific key (e.g., an IP address)
  64. within the current time window.
  65. Parameters:
  66. key (str): The key for which the request count is being updated.
  67. For example, an IP address if rate limiting based on IPs.
  68. Raises:
  69. RateLimitExceededException: If the number of requests exceeds the
  70. allowed limit for the current time window.
  71. """
  72. self.logger.debug(f"Counting a request for key: {key}")
  73. pfx_key = self.__get_prefixed_key(key)
  74. # Check if the key already exists in Redis
  75. if self.__redis_client.exists(pfx_key):
  76. current_count = int(self.__redis_client.get(pfx_key))
  77. self.logger.debug(
  78. f"Current request count for '{pfx_key}': {current_count}"
  79. )
  80. # If request count exceeds the allowed limit, raise exception
  81. if current_count >= self.__allowed_requests:
  82. self.logger.warning(f"Rate limit exceeded for key '{pfx_key}'")
  83. raise RateLimitExceededException
  84. # Increment request count and keep TTL (time-to-live) unchanged
  85. self.__redis_client.set(
  86. name=pfx_key,
  87. value=(current_count + 1),
  88. keepttl=True
  89. )
  90. else:
  91. # Key doesn't exist yet, initialise count with TTL for time window
  92. self.logger.debug(
  93. f"No previous requests for key '{pfx_key}' in current window"
  94. ", initialising count to 1"
  95. )
  96. self.__redis_client.set(
  97. name=pfx_key,
  98. value=1,
  99. ex=self.__time_window_sec
  100. )
  101. class RateLimiterException(Exception):
  102. pass
  103. class RateLimitExceededException(RateLimiterException):
  104. def __init__(self, *args):
  105. message = "Too many requests. Try after some time."
  106. super().__init__(message)