puppet.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, TYPE_CHECKING, cast
  17. from aiohttp import ClientSession
  18. from yarl import URL
  19. from mauigpapi.types import BaseResponseUser
  20. from mautrix.bridge import BasePuppet
  21. from mautrix.appservice import IntentAPI
  22. from mautrix.types import ContentURI, UserID, SyncToken, RoomID
  23. from mautrix.util.simple_template import SimpleTemplate
  24. from .db import Puppet as DBPuppet
  25. from .config import Config
  26. from . import portal as p, user as u
  27. if TYPE_CHECKING:
  28. from .__main__ import InstagramBridge
  29. class Puppet(DBPuppet, BasePuppet):
  30. by_pk: Dict[int, 'Puppet'] = {}
  31. by_custom_mxid: Dict[UserID, 'Puppet'] = {}
  32. hs_domain: str
  33. mxid_template: SimpleTemplate[int]
  34. config: Config
  35. default_mxid_intent: IntentAPI
  36. default_mxid: UserID
  37. def __init__(self, pk: int, name: Optional[str] = None, username: Optional[str] = None,
  38. photo_id: Optional[str] = None, photo_mxc: Optional[ContentURI] = None,
  39. name_set: bool = False, avatar_set: bool = False, is_registered: bool = False,
  40. custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
  41. next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
  42. super().__init__(pk=pk, name=name, username=username, photo_id=photo_id, name_set=name_set,
  43. photo_mxc=photo_mxc, avatar_set=avatar_set, is_registered=is_registered,
  44. custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
  45. base_url=base_url)
  46. self.log = self.log.getChild(str(pk))
  47. self.default_mxid = self.get_mxid_from_id(pk)
  48. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  49. self.intent = self._fresh_intent()
  50. @classmethod
  51. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  52. cls.config = bridge.config
  53. cls.loop = bridge.loop
  54. cls.mx = bridge.matrix
  55. cls.az = bridge.az
  56. cls.hs_domain = cls.config["homeserver.domain"]
  57. cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
  58. prefix="@", suffix=f":{cls.hs_domain}", type=int)
  59. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  60. cls.homeserver_url_map = {server: URL(url) for server, url
  61. in cls.config["bridge.double_puppet_server_map"].items()}
  62. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  63. cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
  64. in cls.config["bridge.login_shared_secret_map"].items()}
  65. cls.login_device_name = "Instagram Bridge"
  66. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  67. @property
  68. def igpk(self) -> int:
  69. return self.pk
  70. def intent_for(self, portal: 'p.Portal') -> IntentAPI:
  71. if portal.other_user_pk == self.pk:
  72. return self.default_mxid_intent
  73. return self.intent
  74. def need_backfill_invite(self, portal: 'p.Portal') -> bool:
  75. return (portal.other_user_pk != self.pk and (self.is_real_user or portal.is_direct)
  76. and self.config["bridge.backfill.invite_own_puppet"])
  77. async def update_info(self, info: BaseResponseUser, source: 'u.User') -> None:
  78. update = False
  79. update = await self._update_name(info) or update
  80. update = await self._update_avatar(info, source) or update
  81. if update:
  82. await self.update()
  83. @classmethod
  84. def _get_displayname(cls, info: BaseResponseUser) -> str:
  85. return cls.config["bridge.displayname_template"].format(displayname=info.full_name,
  86. id=info.pk, username=info.username)
  87. async def _update_name(self, info: BaseResponseUser) -> bool:
  88. name = self._get_displayname(info)
  89. if name != self.name:
  90. self.name = name
  91. try:
  92. await self.default_mxid_intent.set_displayname(self.name)
  93. self.name_set = True
  94. except Exception:
  95. self.log.exception("Failed to update displayname")
  96. self.name_set = False
  97. return True
  98. return False
  99. async def _update_avatar(self, info: BaseResponseUser, source: 'u.User') -> bool:
  100. if info.profile_pic_id != self.photo_id or not self.avatar_set:
  101. self.photo_id = info.profile_pic_id
  102. if info.profile_pic_id:
  103. async with source.client.raw_http_get(info.profile_pic_url) as resp:
  104. content_type = resp.headers["Content-Type"]
  105. resp_data = await resp.read()
  106. mxc = await self.default_mxid_intent.upload_media(data=resp_data,
  107. mime_type=content_type,
  108. filename=info.profile_pic_id)
  109. else:
  110. mxc = None
  111. try:
  112. await self.default_mxid_intent.set_avatar_url(mxc)
  113. self.avatar_set = True
  114. self.photo_mxc = mxc
  115. except Exception:
  116. self.log.exception("Failed to update avatar")
  117. self.avatar_set = False
  118. return True
  119. return False
  120. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  121. portal = await p.Portal.get_by_mxid(room_id)
  122. return portal and portal.other_user_pk != self.pk
  123. # region Database getters
  124. def _add_to_cache(self) -> None:
  125. self.by_pk[self.pk] = self
  126. if self.custom_mxid:
  127. self.by_custom_mxid[self.custom_mxid] = self
  128. async def save(self) -> None:
  129. await self.update()
  130. @classmethod
  131. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
  132. pk = cls.get_id_from_mxid(mxid)
  133. if pk:
  134. return await cls.get_by_pk(pk, create)
  135. return None
  136. @classmethod
  137. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  138. try:
  139. return cls.by_custom_mxid[mxid]
  140. except KeyError:
  141. pass
  142. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  143. if puppet:
  144. puppet._add_to_cache()
  145. return puppet
  146. return None
  147. @classmethod
  148. def get_id_from_mxid(cls, mxid: UserID) -> Optional[int]:
  149. return cls.mxid_template.parse(mxid)
  150. @classmethod
  151. def get_mxid_from_id(cls, twid: int) -> UserID:
  152. return UserID(cls.mxid_template.format_full(twid))
  153. @classmethod
  154. async def get_by_pk(cls, pk: int, create: bool = True) -> Optional['Puppet']:
  155. try:
  156. return cls.by_pk[pk]
  157. except KeyError:
  158. pass
  159. puppet = cast(cls, await super().get_by_pk(pk))
  160. if puppet is not None:
  161. puppet._add_to_cache()
  162. return puppet
  163. if create:
  164. puppet = cls(pk)
  165. await puppet.insert()
  166. puppet._add_to_cache()
  167. return puppet
  168. return None
  169. @classmethod
  170. async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
  171. puppets = await super().all_with_custom_mxid()
  172. puppet: cls
  173. for index, puppet in enumerate(puppets):
  174. try:
  175. yield cls.by_pk[puppet.pk]
  176. except KeyError:
  177. puppet._add_to_cache()
  178. yield puppet
  179. # endregion