puppet.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, TYPE_CHECKING, cast
  17. from aiohttp import ClientSession
  18. from yarl import URL
  19. from mauigpapi.types import BaseResponseUser
  20. from mautrix.bridge import BasePuppet
  21. from mautrix.appservice import IntentAPI
  22. from mautrix.types import ContentURI, UserID, SyncToken, RoomID
  23. from mautrix.util.simple_template import SimpleTemplate
  24. from .db import Puppet as DBPuppet
  25. from .config import Config
  26. from . import portal as p
  27. if TYPE_CHECKING:
  28. from .__main__ import InstagramBridge
  29. class Puppet(DBPuppet, BasePuppet):
  30. by_pk: Dict[int, 'Puppet'] = {}
  31. by_custom_mxid: Dict[UserID, 'Puppet'] = {}
  32. hs_domain: str
  33. mxid_template: SimpleTemplate[int]
  34. config: Config
  35. default_mxid_intent: IntentAPI
  36. default_mxid: UserID
  37. def __init__(self, pk: int, name: Optional[str] = None, username: Optional[str] = None,
  38. photo_id: Optional[str] = None, photo_mxc: Optional[ContentURI] = None,
  39. name_set: bool = False, avatar_set: bool = False, is_registered: bool = False,
  40. custom_mxid: Optional[UserID] = None, access_token: Optional[str] = None,
  41. next_batch: Optional[SyncToken] = None, base_url: Optional[URL] = None) -> None:
  42. super().__init__(pk=pk, name=name, username=username, photo_id=photo_id, name_set=name_set,
  43. photo_mxc=photo_mxc, avatar_set=avatar_set, is_registered=is_registered,
  44. custom_mxid=custom_mxid, access_token=access_token, next_batch=next_batch,
  45. base_url=base_url)
  46. self.log = self.log.getChild(str(pk))
  47. self.default_mxid = self.get_mxid_from_id(pk)
  48. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  49. self.intent = self._fresh_intent()
  50. @classmethod
  51. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  52. cls.config = bridge.config
  53. cls.loop = bridge.loop
  54. cls.mx = bridge.matrix
  55. cls.az = bridge.az
  56. cls.hs_domain = cls.config["homeserver.domain"]
  57. cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
  58. prefix="@", suffix=f":{cls.hs_domain}", type=int)
  59. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  60. cls.homeserver_url_map = {server: URL(url) for server, url
  61. in cls.config["bridge.double_puppet_server_map"].items()}
  62. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  63. cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
  64. in cls.config["bridge.login_shared_secret_map"].items()}
  65. cls.login_device_name = "Instagram Bridge"
  66. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  67. def intent_for(self, portal: 'p.Portal') -> IntentAPI:
  68. if portal.other_user_pk == self.pk or (self.config["bridge.backfill.invite_own_puppet"]
  69. and portal.backfill_lock.locked):
  70. return self.default_mxid_intent
  71. return self.intent
  72. async def update_info(self, info: BaseResponseUser) -> None:
  73. update = False
  74. update = await self._update_name(info) or update
  75. update = await self._update_avatar(info) or update
  76. if update:
  77. await self.update()
  78. @classmethod
  79. def _get_displayname(cls, info: BaseResponseUser) -> str:
  80. return cls.config["bridge.displayname_template"].format(displayname=info.full_name,
  81. id=info.pk, username=info.username)
  82. async def _update_name(self, info: BaseResponseUser) -> bool:
  83. name = self._get_displayname(info)
  84. if name != self.name:
  85. self.name = name
  86. try:
  87. await self.default_mxid_intent.set_displayname(self.name)
  88. self.name_set = True
  89. except Exception:
  90. self.log.exception("Failed to update displayname")
  91. self.name_set = False
  92. return True
  93. return False
  94. async def _update_avatar(self, info: BaseResponseUser) -> bool:
  95. if info.profile_pic_id != self.photo_id or not self.avatar_set:
  96. self.photo_id = info.profile_pic_id
  97. if info.profile_pic_id:
  98. # TODO if info.has_anonymous_profile_picture, we might need auth to get it
  99. # ...and we should probably download it with the device headers anyway
  100. async with ClientSession() as sess, sess.get(info.profile_pic_url) as resp:
  101. content_type = resp.headers["Content-Type"]
  102. resp_data = await resp.read()
  103. mxc = await self.default_mxid_intent.upload_media(data=resp_data,
  104. mime_type=content_type,
  105. filename=info.profile_pic_id)
  106. else:
  107. mxc = None
  108. try:
  109. await self.default_mxid_intent.set_avatar_url(mxc)
  110. self.avatar_set = True
  111. self.photo_mxc = mxc
  112. except Exception:
  113. self.log.exception("Failed to update avatar")
  114. self.avatar_set = False
  115. return True
  116. return False
  117. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  118. portal = await p.Portal.get_by_mxid(room_id)
  119. return portal and portal.other_user_pk != self.pk
  120. # region Database getters
  121. def _add_to_cache(self) -> None:
  122. self.by_pk[self.pk] = self
  123. if self.custom_mxid:
  124. self.by_custom_mxid[self.custom_mxid] = self
  125. async def save(self) -> None:
  126. await self.update()
  127. @classmethod
  128. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
  129. pk = cls.get_id_from_mxid(mxid)
  130. if pk:
  131. return await cls.get_by_pk(pk, create)
  132. return None
  133. @classmethod
  134. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  135. try:
  136. return cls.by_custom_mxid[mxid]
  137. except KeyError:
  138. pass
  139. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  140. if puppet:
  141. puppet._add_to_cache()
  142. return puppet
  143. return None
  144. @classmethod
  145. def get_id_from_mxid(cls, mxid: UserID) -> Optional[int]:
  146. return cls.mxid_template.parse(mxid)
  147. @classmethod
  148. def get_mxid_from_id(cls, twid: int) -> UserID:
  149. return UserID(cls.mxid_template.format_full(twid))
  150. @classmethod
  151. async def get_by_pk(cls, pk: int, create: bool = True) -> Optional['Puppet']:
  152. try:
  153. return cls.by_pk[pk]
  154. except KeyError:
  155. pass
  156. puppet = cast(cls, await super().get_by_pk(pk))
  157. if puppet is not None:
  158. puppet._add_to_cache()
  159. return puppet
  160. if create:
  161. puppet = cls(pk)
  162. await puppet.insert()
  163. puppet._add_to_cache()
  164. return puppet
  165. return None
  166. @classmethod
  167. async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
  168. puppets = await super().all_with_custom_mxid()
  169. puppet: cls
  170. for index, puppet in enumerate(puppets):
  171. try:
  172. yield cls.by_pk[puppet.pk]
  173. except KeyError:
  174. puppet._add_to_cache()
  175. yield puppet
  176. # endregion