from __future__ import annotations from typing import Awaitable, Callable, TypeVar import asyncio import json import logging import urllib.request from aiohttp import ClientConnectionError from yarl import URL from mautrix.util.logging import TraceLogger try: from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError except ImportError: class ProxyError(Exception): pass ProxyConnectionError = ProxyTimeoutError = ProxyError RETRYABLE_PROXY_EXCEPTIONS = ( ProxyError, ProxyTimeoutError, ProxyConnectionError, ClientConnectionError, ConnectionError, asyncio.TimeoutError, ) class ProxyHandler: current_proxy_url: str | None = None log = logging.getLogger("mauigpapi.proxy") def __init__(self, api_url: str | None) -> None: self.api_url = api_url def get_proxy_url_from_api(self, reason: str | None = None) -> str | None: assert self.api_url is not None api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {})) request = urllib.request.Request(api_url, method="GET") self.log.debug("Requesting proxy from: %s", api_url) try: with urllib.request.urlopen(request) as f: response = json.loads(f.read().decode()) except Exception: self.log.exception("Failed to retrieve proxy from API") else: return response["proxy_url"] return None def update_proxy_url(self, reason: str | None = None) -> bool: old_proxy = self.current_proxy_url new_proxy = None if self.api_url is not None: new_proxy = self.get_proxy_url_from_api(reason) else: new_proxy = urllib.request.getproxies().get("http") if old_proxy != new_proxy: self.log.debug("Set new proxy URL: %s", new_proxy) self.current_proxy_url = new_proxy return True self.log.debug("Got same proxy URL: %s", new_proxy) return False def get_proxy_url(self) -> str | None: if not self.current_proxy_url: self.update_proxy_url() return self.current_proxy_url T = TypeVar("T") async def proxy_with_retry( name: str, func: Callable[[], Awaitable[T]], logger: TraceLogger, proxy_handler: ProxyHandler, on_proxy_change: Callable[[], Awaitable[None]], max_retries: int = 10, ) -> T: errors = 0 while True: try: return await func() except RETRYABLE_PROXY_EXCEPTIONS as e: errors += 1 if errors > max_retries: raise wait = min(errors * 10, 60) logger.warning( "%s while trying to %s, retrying in %d seconds", e.__class__.__name__, name, wait, ) await asyncio.sleep(wait) if errors > 1 and proxy_handler.update_proxy_url( f"{e.__class__.__name__} while trying to {name}" ): await on_proxy_change()