puppet.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, TYPE_CHECKING, cast
  17. from aiohttp import ClientSession
  18. from yarl import URL
  19. from mauigpapi.types import BaseResponseUser
  20. from mautrix.bridge import BasePuppet
  21. from mautrix.appservice import IntentAPI
  22. from mautrix.types import ContentURI, UserID, SyncToken, RoomID
  23. from mautrix.util.simple_template import SimpleTemplate
  24. from .db import Puppet as DBPuppet
  25. from .config import Config
  26. from . import portal as p
  27. if TYPE_CHECKING:
  28. from .__main__ import InstagramBridge
  29. class Puppet(DBPuppet, BasePuppet):
  30. by_pk: Dict[int, 'Puppet'] = {}
  31. by_custom_mxid: Dict[UserID, 'Puppet'] = {}
  32. hs_domain: str
  33. mxid_template: SimpleTemplate[int]
  34. config: Config
  35. default_mxid_intent: IntentAPI
  36. default_mxid: UserID
  37. def __init__(self, pk: int, name: Optional[str] = None, username: Optional[str] = None,
  38. photo_id: Optional[str] = None, photo_mxc: Optional[ContentURI] = None,
  39. name_set: bool = False, avatar_set: bool = False, is_registered: bool = False,
  40. custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
  41. next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
  42. super().__init__(pk=pk, name=name, username=username, photo_id=photo_id, name_set=name_set,
  43. photo_mxc=photo_mxc, avatar_set=avatar_set, is_registered=is_registered,
  44. custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
  45. base_url=base_url)
  46. self.log = self.log.getChild(str(pk))
  47. self.default_mxid = self.get_mxid_from_id(pk)
  48. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  49. self.intent = self._fresh_intent()
  50. @classmethod
  51. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  52. cls.config = bridge.config
  53. cls.loop = bridge.loop
  54. cls.mx = bridge.matrix
  55. cls.az = bridge.az
  56. cls.hs_domain = cls.config["homeserver.domain"]
  57. cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
  58. prefix="@", suffix=f":{cls.hs_domain}", type=int)
  59. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  60. cls.homeserver_url_map = {server: URL(url) for server, url
  61. in cls.config["bridge.double_puppet_server_map"].items()}
  62. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  63. cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
  64. in cls.config["bridge.login_shared_secret_map"].items()}
  65. cls.login_device_name = "Instagram Bridge"
  66. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  67. @property
  68. def igpk(self) -> int:
  69. return self.pk
  70. def intent_for(self, portal: 'p.Portal') -> IntentAPI:
  71. if portal.other_user_pk == self.pk or (self.config["bridge.backfill.invite_own_puppet"]
  72. and portal.backfill_lock.locked):
  73. return self.default_mxid_intent
  74. return self.intent
  75. async def update_info(self, info: BaseResponseUser) -> None:
  76. update = False
  77. update = await self._update_name(info) or update
  78. update = await self._update_avatar(info) or update
  79. if update:
  80. await self.update()
  81. @classmethod
  82. def _get_displayname(cls, info: BaseResponseUser) -> str:
  83. return cls.config["bridge.displayname_template"].format(displayname=info.full_name,
  84. id=info.pk, username=info.username)
  85. async def _update_name(self, info: BaseResponseUser) -> bool:
  86. name = self._get_displayname(info)
  87. if name != self.name:
  88. self.name = name
  89. try:
  90. await self.default_mxid_intent.set_displayname(self.name)
  91. self.name_set = True
  92. except Exception:
  93. self.log.exception("Failed to update displayname")
  94. self.name_set = False
  95. return True
  96. return False
  97. async def _update_avatar(self, info: BaseResponseUser) -> bool:
  98. if info.profile_pic_id != self.photo_id or not self.avatar_set:
  99. self.photo_id = info.profile_pic_id
  100. if info.profile_pic_id:
  101. # TODO if info.has_anonymous_profile_picture, we might need auth to get it
  102. # ...and we should probably download it with the device headers anyway
  103. async with ClientSession() as sess, sess.get(info.profile_pic_url) as resp:
  104. content_type = resp.headers["Content-Type"]
  105. resp_data = await resp.read()
  106. mxc = await self.default_mxid_intent.upload_media(data=resp_data,
  107. mime_type=content_type,
  108. filename=info.profile_pic_id)
  109. else:
  110. mxc = None
  111. try:
  112. await self.default_mxid_intent.set_avatar_url(mxc)
  113. self.avatar_set = True
  114. self.photo_mxc = mxc
  115. except Exception:
  116. self.log.exception("Failed to update avatar")
  117. self.avatar_set = False
  118. return True
  119. return False
  120. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  121. portal = await p.Portal.get_by_mxid(room_id)
  122. return portal and portal.other_user_pk != self.pk
  123. # region Database getters
  124. def _add_to_cache(self) -> None:
  125. self.by_pk[self.pk] = self
  126. if self.custom_mxid:
  127. self.by_custom_mxid[self.custom_mxid] = self
  128. async def save(self) -> None:
  129. await self.update()
  130. @classmethod
  131. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
  132. pk = cls.get_id_from_mxid(mxid)
  133. if pk:
  134. return await cls.get_by_pk(pk, create)
  135. return None
  136. @classmethod
  137. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  138. try:
  139. return cls.by_custom_mxid[mxid]
  140. except KeyError:
  141. pass
  142. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  143. if puppet:
  144. puppet._add_to_cache()
  145. return puppet
  146. return None
  147. @classmethod
  148. def get_id_from_mxid(cls, mxid: UserID) -> Optional[int]:
  149. return cls.mxid_template.parse(mxid)
  150. @classmethod
  151. def get_mxid_from_id(cls, twid: int) -> UserID:
  152. return UserID(cls.mxid_template.format_full(twid))
  153. @classmethod
  154. async def get_by_pk(cls, pk: int, create: bool = True) -> Optional['Puppet']:
  155. try:
  156. return cls.by_pk[pk]
  157. except KeyError:
  158. pass
  159. puppet = cast(cls, await super().get_by_pk(pk))
  160. if puppet is not None:
  161. puppet._add_to_cache()
  162. return puppet
  163. if create:
  164. puppet = cls(pk)
  165. await puppet.insert()
  166. puppet._add_to_cache()
  167. return puppet
  168. return None
  169. @classmethod
  170. async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
  171. puppets = await super().all_with_custom_mxid()
  172. puppet: cls
  173. for index, puppet in enumerate(puppets):
  174. try:
  175. yield cls.by_pk[puppet.pk]
  176. except KeyError:
  177. puppet._add_to_cache()
  178. yield puppet
  179. # endregion