puppet.py 8.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, TYPE_CHECKING, cast
  17. from aiohttp import ClientSession
  18. from yarl import URL
  19. from mauigpapi.types import BaseResponseUser
  20. from mautrix.bridge import BasePuppet
  21. from mautrix.appservice import IntentAPI
  22. from mautrix.types import ContentURI, UserID, SyncToken, RoomID
  23. from mautrix.util.simple_template import SimpleTemplate
  24. from .db import Puppet as DBPuppet
  25. from .config import Config
  26. from . import portal as p
  27. if TYPE_CHECKING:
  28. from .__main__ import InstagramBridge
  29. class Puppet(DBPuppet, BasePuppet):
  30. by_pk: Dict[int, 'Puppet'] = {}
  31. by_custom_mxid: Dict[UserID, 'Puppet'] = {}
  32. hs_domain: str
  33. mxid_template: SimpleTemplate[int]
  34. config: Config
  35. default_mxid_intent: IntentAPI
  36. default_mxid: UserID
  37. def __init__(self, pk: int, name: Optional[str] = None, username: Optional[str] = None,
  38. photo_id: Optional[str] = None, photo_mxc: Optional[ContentURI] = None,
  39. name_set: bool = False, avatar_set: bool = False, is_registered: bool = False,
  40. custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
  41. next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
  42. super().__init__(pk=pk, name=name, username=username, photo_id=photo_id, name_set=name_set,
  43. photo_mxc=photo_mxc, avatar_set=avatar_set, is_registered=is_registered,
  44. custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
  45. base_url=base_url)
  46. self.log = self.log.getChild(str(pk))
  47. self.default_mxid = self.get_mxid_from_id(pk)
  48. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  49. self.intent = self._fresh_intent()
  50. @classmethod
  51. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  52. cls.config = bridge.config
  53. cls.loop = bridge.loop
  54. cls.mx = bridge.matrix
  55. cls.az = bridge.az
  56. cls.hs_domain = cls.config["homeserver.domain"]
  57. cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
  58. prefix="@", suffix=f":{cls.hs_domain}", type=int)
  59. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  60. cls.homeserver_url_map = {server: URL(url) for server, url
  61. in cls.config["bridge.double_puppet_server_map"].items()}
  62. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  63. cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
  64. in cls.config["bridge.login_shared_secret_map"].items()}
  65. cls.login_device_name = "Instagram Bridge"
  66. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  67. @property
  68. def igpk(self) -> int:
  69. return self.pk
  70. def intent_for(self, portal: 'p.Portal') -> IntentAPI:
  71. if portal.other_user_pk == self.pk:
  72. return self.default_mxid_intent
  73. return self.intent
  74. def need_backfill_invite(self, portal: 'p.Portal') -> bool:
  75. return (portal.other_user_pk != self.pk and self.is_real_user
  76. and self.config["bridge.backfill.invite_own_puppet"])
  77. async def update_info(self, info: BaseResponseUser) -> None:
  78. update = False
  79. update = await self._update_name(info) or update
  80. update = await self._update_avatar(info) or update
  81. if update:
  82. await self.update()
  83. @classmethod
  84. def _get_displayname(cls, info: BaseResponseUser) -> str:
  85. return cls.config["bridge.displayname_template"].format(displayname=info.full_name,
  86. id=info.pk, username=info.username)
  87. async def _update_name(self, info: BaseResponseUser) -> bool:
  88. name = self._get_displayname(info)
  89. if name != self.name:
  90. self.name = name
  91. try:
  92. await self.default_mxid_intent.set_displayname(self.name)
  93. self.name_set = True
  94. except Exception:
  95. self.log.exception("Failed to update displayname")
  96. self.name_set = False
  97. return True
  98. return False
  99. async def _update_avatar(self, info: BaseResponseUser) -> bool:
  100. if info.profile_pic_id != self.photo_id or not self.avatar_set:
  101. self.photo_id = info.profile_pic_id
  102. if info.profile_pic_id:
  103. # TODO if info.has_anonymous_profile_picture, we might need auth to get it
  104. # ...and we should probably download it with the device headers anyway
  105. async with ClientSession() as sess, sess.get(info.profile_pic_url) as resp:
  106. content_type = resp.headers["Content-Type"]
  107. resp_data = await resp.read()
  108. mxc = await self.default_mxid_intent.upload_media(data=resp_data,
  109. mime_type=content_type,
  110. filename=info.profile_pic_id)
  111. else:
  112. mxc = None
  113. try:
  114. await self.default_mxid_intent.set_avatar_url(mxc)
  115. self.avatar_set = True
  116. self.photo_mxc = mxc
  117. except Exception:
  118. self.log.exception("Failed to update avatar")
  119. self.avatar_set = False
  120. return True
  121. return False
  122. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  123. portal = await p.Portal.get_by_mxid(room_id)
  124. return portal and portal.other_user_pk != self.pk
  125. # region Database getters
  126. def _add_to_cache(self) -> None:
  127. self.by_pk[self.pk] = self
  128. if self.custom_mxid:
  129. self.by_custom_mxid[self.custom_mxid] = self
  130. async def save(self) -> None:
  131. await self.update()
  132. @classmethod
  133. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
  134. pk = cls.get_id_from_mxid(mxid)
  135. if pk:
  136. return await cls.get_by_pk(pk, create)
  137. return None
  138. @classmethod
  139. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  140. try:
  141. return cls.by_custom_mxid[mxid]
  142. except KeyError:
  143. pass
  144. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  145. if puppet:
  146. puppet._add_to_cache()
  147. return puppet
  148. return None
  149. @classmethod
  150. def get_id_from_mxid(cls, mxid: UserID) -> Optional[int]:
  151. return cls.mxid_template.parse(mxid)
  152. @classmethod
  153. def get_mxid_from_id(cls, twid: int) -> UserID:
  154. return UserID(cls.mxid_template.format_full(twid))
  155. @classmethod
  156. async def get_by_pk(cls, pk: int, create: bool = True) -> Optional['Puppet']:
  157. try:
  158. return cls.by_pk[pk]
  159. except KeyError:
  160. pass
  161. puppet = cast(cls, await super().get_by_pk(pk))
  162. if puppet is not None:
  163. puppet._add_to_cache()
  164. return puppet
  165. if create:
  166. puppet = cls(pk)
  167. await puppet.insert()
  168. puppet._add_to_cache()
  169. return puppet
  170. return None
  171. @classmethod
  172. async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
  173. puppets = await super().all_with_custom_mxid()
  174. puppet: cls
  175. for index, puppet in enumerate(puppets):
  176. try:
  177. yield cls.by_pk[puppet.pk]
  178. except KeyError:
  179. puppet._add_to_cache()
  180. yield puppet
  181. # endregion