123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113 |
- from __future__ import annotations
- from typing import Awaitable, Callable, TypeVar
- import asyncio
- import json
- import logging
- import urllib.request
- from aiohttp import ClientConnectionError
- from yarl import URL
- from mautrix.util.logging import TraceLogger
try:
    # Use the real exception types when the optional SOCKS support is installed.
    from aiohttp_socks import ProxyConnectionError, ProxyError, ProxyTimeoutError
except ImportError:

    class ProxyError(Exception):
        """Stand-in used when aiohttp_socks is not installed, so the name stays usable."""

    # Both remaining names alias the single stand-in class, so catching any of
    # the three proxy error names behaves the same without aiohttp_socks.
    ProxyTimeoutError = ProxyError
    ProxyConnectionError = ProxyError
# Failures raised by the proxy layer (real aiohttp_socks types or the local stand-in).
_PROXY_LAYER_ERRORS = (ProxyError, ProxyTimeoutError, ProxyConnectionError)
# Connection-level failures and timeouts from aiohttp / the standard library.
_NETWORK_ERRORS = (ClientConnectionError, ConnectionError, asyncio.TimeoutError)
# Exception types that proxy_with_retry catches and retries instead of propagating.
RETRYABLE_PROXY_EXCEPTIONS = _PROXY_LAYER_ERRORS + _NETWORK_ERRORS
class ProxyHandler:
    """Track the proxy URL currently in use and re-resolve it on demand.

    The proxy URL comes from a proxy API queried with a GET request when
    ``api_url`` is set. Otherwise it comes from the ``http`` entry of
    :func:`urllib.request.getproxies` (i.e. the environment / system settings).
    """

    # Last URL stored by update_proxy_url(); None until the first lookup or
    # whenever the latest lookup produced no proxy.
    current_proxy_url: str | None = None
    # Seconds to wait for the proxy API. Also declared at class level so that
    # instances created without __init__ still have a value.
    api_timeout: float | None = 30.0
    log = logging.getLogger("mauigpapi.proxy")

    def __init__(self, api_url: str | None, *, api_timeout: float | None = 30.0) -> None:
        """
        Args:
            api_url: Endpoint expected to return a JSON object with a
                ``proxy_url`` key, or None to read the proxy from
                ``urllib.request.getproxies()`` instead.
            api_timeout: Seconds to wait for the proxy API before giving up.
                None means no timeout of our own (urllib's default applies).
        """
        self.api_url = api_url
        self.api_timeout = api_timeout

    def get_proxy_url_from_api(self, reason: str | None = None) -> str | None:
        """Ask the proxy API for a proxy URL.

        Args:
            reason: Optional explanation, sent to the API as the ``reason``
                query parameter.

        Returns:
            The ``proxy_url`` value from the API's JSON response. Returns None
            if the request or JSON decoding failed, or if the response has no
            ``proxy_url`` field.
        """
        assert self.api_url is not None

        api_url = str(URL(self.api_url).update_query({"reason": reason} if reason else {}))
        request = urllib.request.Request(api_url, method="GET")
        self.log.debug("Requesting proxy from: %s", api_url)

        # urlopen is blocking. Without a timeout a hung endpoint would stall the
        # caller indefinitely.
        open_kwargs = {} if self.api_timeout is None else {"timeout": self.api_timeout}
        try:
            with urllib.request.urlopen(request, **open_kwargs) as f:
                response = json.load(f)
        except Exception:
            # Best effort: any network, HTTP or JSON failure means "no proxy".
            self.log.exception("Failed to retrieve proxy from API")
            return None

        try:
            return response["proxy_url"]
        except (KeyError, TypeError):
            # The body was valid JSON but not an object with proxy_url
            # (e.g. an error object or a list). Treat it like any other failure
            # instead of letting the exception escape to callers.
            self.log.error("Proxy API returned an unexpected response: %r", response)
            return None

    def update_proxy_url(self, reason: str | None = None) -> bool:
        """Re-resolve the proxy URL and store it in ``current_proxy_url``.

        Args:
            reason: Forwarded to the proxy API. Unused when no API is configured.

        Returns:
            True if the stored proxy URL changed, False if it is unchanged.
        """
        old_proxy = self.current_proxy_url
        if self.api_url is not None:
            # NOTE(review): a failed API lookup yields None, which replaces a
            # previously working proxy. Confirm this is intended.
            new_proxy = self.get_proxy_url_from_api(reason)
        else:
            new_proxy = urllib.request.getproxies().get("http")

        if old_proxy != new_proxy:
            self.log.debug("Set new proxy URL: %s", new_proxy)
            self.current_proxy_url = new_proxy
            return True

        self.log.debug("Got same proxy URL: %s", new_proxy)
        return False

    def get_proxy_url(self) -> str | None:
        """Return the current proxy URL, resolving it first if none is stored.

        A falsy stored value (no proxy found) triggers a fresh lookup on every call.
        """
        if not self.current_proxy_url:
            self.update_proxy_url()
        return self.current_proxy_url
T = TypeVar("T")


async def proxy_with_retry(
    name: str,
    func: Callable[[], Awaitable[T]],
    logger: TraceLogger,
    proxy_handler: ProxyHandler,
    on_proxy_change: Callable[[], Awaitable[None]],
    max_retries: int = 10,
    *,
    retry_delay: int = 10,
    max_retry_delay: int = 60,
) -> T:
    """Await ``func()``, retrying on transient proxy/connection errors.

    Each failure matching ``RETRYABLE_PROXY_EXCEPTIONS`` is logged, then the
    function sleeps for ``failures * retry_delay`` seconds, capped at
    ``max_retry_delay``. From the second failure onward the proxy is
    re-resolved via ``proxy_handler.update_proxy_url``. If the URL changed,
    ``on_proxy_change`` is awaited before the next attempt.

    Args:
        name: Human-readable action name, used in log messages and as the
            reason passed to the proxy handler.
        func: Zero-argument coroutine factory to call on each attempt.
        logger: Logger that receives a warning for each retried failure.
        proxy_handler: Handler whose proxy URL is refreshed after repeated failures.
        on_proxy_change: Awaited after the proxy URL changes.
        max_retries: Number of retries allowed before giving up.
        retry_delay: Linear backoff step in seconds.
        max_retry_delay: Upper bound on a single backoff sleep in seconds.

    Returns:
        The result of the first successful ``func()`` call.

    Raises:
        The last retryable exception, once more than ``max_retries`` failures
        occurred. Non-retryable exceptions propagate immediately.
    """
    errors = 0
    while True:
        try:
            return await func()
        except RETRYABLE_PROXY_EXCEPTIONS as e:
            errors += 1
            if errors > max_retries:
                raise
            wait = min(errors * retry_delay, max_retry_delay)
            logger.warning(
                "%s while trying to %s, retrying in %d seconds",
                e.__class__.__name__,
                name,
                wait,
            )
            await asyncio.sleep(wait)
            if errors > 1:
                # update_proxy_url() may perform a blocking urllib request, so run
                # it in the default executor instead of stalling the event loop.
                loop = asyncio.get_running_loop()
                changed = await loop.run_in_executor(
                    None,
                    proxy_handler.update_proxy_url,
                    f"{e.__class__.__name__} while trying to {name}",
                )
                if changed:
                    await on_proxy_change()
|