proxy.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113
  1. from __future__ import annotations
  2. from typing import Awaitable, Callable, TypeVar
  3. import asyncio
  4. import json
  5. import logging
  6. import urllib.request
  7. from aiohttp import ClientConnectionError
  8. from yarl import URL
  9. from mautrix.util.logging import TraceLogger
  10. try:
  11. from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError
  12. except ImportError:
  13. class ProxyError(Exception):
  14. pass
  15. ProxyConnectionError = ProxyTimeoutError = ProxyError
  16. RETRYABLE_PROXY_EXCEPTIONS = (
  17. ProxyError,
  18. ProxyTimeoutError,
  19. ProxyConnectionError,
  20. ClientConnectionError,
  21. ConnectionError,
  22. asyncio.TimeoutError,
  23. )
  24. class ProxyHandler:
  25. current_proxy_url: str | None = None
  26. log = logging.getLogger("mauigpapi.proxy")
  27. def __init__(self, api_url: str | None) -> None:
  28. self.api_url = api_url
  29. def get_proxy_url_from_api(self, reason: str | None = None) -> str | None:
  30. assert self.api_url is not None
  31. api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {}))
  32. request = urllib.request.Request(api_url, method="GET")
  33. self.log.debug("Requesting proxy from: %s", api_url)
  34. try:
  35. with urllib.request.urlopen(request) as f:
  36. response = json.loads(f.read().decode())
  37. except Exception:
  38. self.log.exception("Failed to retrieve proxy from API")
  39. else:
  40. return response["proxy_url"]
  41. return None
  42. def update_proxy_url(self, reason: str | None = None) -> bool:
  43. old_proxy = self.current_proxy_url
  44. new_proxy = None
  45. if self.api_url is not None:
  46. new_proxy = self.get_proxy_url_from_api(reason)
  47. else:
  48. new_proxy = urllib.request.getproxies().get("http")
  49. if old_proxy != new_proxy:
  50. self.log.debug("Set new proxy URL: %s", new_proxy)
  51. self.current_proxy_url = new_proxy
  52. return True
  53. self.log.debug("Got same proxy URL: %s", new_proxy)
  54. return False
  55. def get_proxy_url(self) -> str | None:
  56. if not self.current_proxy_url:
  57. self.update_proxy_url()
  58. return self.current_proxy_url
  59. T = TypeVar("T")
  60. async def proxy_with_retry(
  61. name: str,
  62. func: Callable[[], Awaitable[T]],
  63. logger: TraceLogger,
  64. proxy_handler: ProxyHandler,
  65. on_proxy_change: Callable[[], Awaitable[None]],
  66. max_retries: int = 10,
  67. ) -> T:
  68. errors = 0
  69. while True:
  70. try:
  71. return await func()
  72. except RETRYABLE_PROXY_EXCEPTIONS as e:
  73. errors += 1
  74. if errors > max_retries:
  75. raise
  76. wait = min(errors * 10, 60)
  77. logger.warning(
  78. "%s while trying to %s, retrying in %d seconds",
  79. e.__class__.__name__,
  80. name,
  81. wait,
  82. )
  83. await asyncio.sleep(wait)
  84. if errors > 1 and proxy_handler.update_proxy_url(
  85. f"{e.__class__.__name__} while trying to {name}"
  86. ):
  87. await on_proxy_change()