puppet.py 8.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205
  1. # mautrix-instagram - A Matrix-Instagram puppeting bridge.
  2. # Copyright (C) 2020 Tulir Asokan
  3. #
  4. # This program is free software: you can redistribute it and/or modify
  5. # it under the terms of the GNU Affero General Public License as published by
  6. # the Free Software Foundation, either version 3 of the License, or
  7. # (at your option) any later version.
  8. #
  9. # This program is distributed in the hope that it will be useful,
  10. # but WITHOUT ANY WARRANTY; without even the implied warranty of
  11. # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
  12. # GNU Affero General Public License for more details.
  13. #
  14. # You should have received a copy of the GNU Affero General Public License
  15. # along with this program. If not, see <https://www.gnu.org/licenses/>.
  16. from typing import Optional, Dict, AsyncIterable, Awaitable, AsyncGenerator, TYPE_CHECKING, cast
  17. from aiohttp import ClientSession
  18. from yarl import URL
  19. from mauigpapi.types import BaseResponseUser
  20. from mautrix.bridge import BasePuppet
  21. from mautrix.appservice import IntentAPI
  22. from mautrix.types import ContentURI, UserID, SyncToken, RoomID
  23. from mautrix.util.simple_template import SimpleTemplate
  24. from .db import Puppet as DBPuppet
  25. from .config import Config
  26. from . import portal as p
  27. if TYPE_CHECKING:
  28. from .__main__ import InstagramBridge
  29. class Puppet(DBPuppet, BasePuppet):
  30. by_pk: Dict[int, 'Puppet'] = {}
  31. by_custom_mxid: Dict[UserID, 'Puppet'] = {}
  32. hs_domain: str
  33. mxid_template: SimpleTemplate[int]
  34. config: Config
  35. default_mxid_intent: IntentAPI
  36. default_mxid: UserID
  37. def __init__(self, pk: int, name: Optional[str] = None, username: Optional[str] = None,
  38. photo_id: Optional[str] = None, photo_mxc: Optional[ContentURI] = None,
  39. is_registered: bool = False, custom_mxid: Optional[UserID] = None,
  40. access_token: Optional[str] = None, next_batch: Optional[SyncToken] = None,
  41. base_url: Optional[URL] = None) -> None:
  42. super().__init__(pk=pk, name=name, username=username, photo_id=photo_id, photo_mxc=photo_mxc,
  43. is_registered=is_registered, custom_mxid=custom_mxid,
  44. access_token=access_token, next_batch=next_batch, base_url=base_url)
  45. self.log = self.log.getChild(str(pk))
  46. self.default_mxid = self.get_mxid_from_id(pk)
  47. self.default_mxid_intent = self.az.intent.user(self.default_mxid)
  48. self.intent = self._fresh_intent()
  49. @classmethod
  50. def init_cls(cls, bridge: 'InstagramBridge') -> AsyncIterable[Awaitable[None]]:
  51. cls.config = bridge.config
  52. cls.loop = bridge.loop
  53. cls.mx = bridge.matrix
  54. cls.az = bridge.az
  55. cls.hs_domain = cls.config["homeserver.domain"]
  56. cls.mxid_template = SimpleTemplate(cls.config["bridge.username_template"], "userid",
  57. prefix="@", suffix=f":{cls.hs_domain}", type=int)
  58. cls.sync_with_custom_puppets = cls.config["bridge.sync_with_custom_puppets"]
  59. cls.homeserver_url_map = {server: URL(url) for server, url
  60. in cls.config["bridge.double_puppet_server_map"].items()}
  61. cls.allow_discover_url = cls.config["bridge.double_puppet_allow_discovery"]
  62. cls.login_shared_secret_map = {server: secret.encode("utf-8") for server, secret
  63. in cls.config["bridge.login_shared_secret_map"].items()}
  64. cls.login_device_name = "Instagram Bridge"
  65. return (puppet.try_start() async for puppet in cls.all_with_custom_mxid())
  66. def intent_for(self, portal: 'p.Portal') -> IntentAPI:
  67. if portal.other_user_pk == self.pk or (self.config["bridge.backfill.invite_own_puppet"]
  68. and portal.backfill_lock.locked):
  69. return self.default_mxid_intent
  70. return self.intent
  71. async def update_info(self, info: BaseResponseUser) -> None:
  72. update = False
  73. update = await self._update_name(info) or update
  74. update = await self._update_avatar(info) or update
  75. if update:
  76. await self.update()
  77. @classmethod
  78. def _get_displayname(cls, info: BaseResponseUser) -> str:
  79. return cls.config["bridge.displayname_template"].format(displayname=info.full_name,
  80. id=info.pk, username=info.username)
  81. async def _update_name(self, info: BaseResponseUser) -> bool:
  82. name = self._get_displayname(info)
  83. if name != self.name:
  84. self.name = name
  85. try:
  86. await self.default_mxid_intent.set_displayname(self.name)
  87. self.name_set = True
  88. except Exception:
  89. self.log.exception("Failed to update displayname")
  90. self.name_set = False
  91. return True
  92. return False
  93. async def _update_avatar(self, info: BaseResponseUser) -> bool:
  94. if info.profile_pic_id != self.photo_id or not self.avatar_set:
  95. self.photo_id = info.profile_pic_id
  96. if info.profile_pic_id:
  97. async with ClientSession() as sess, sess.get(info.profile_pic_url) as resp:
  98. content_type = resp.headers["Content-Type"]
  99. resp_data = await resp.read()
  100. mxc = await self.default_mxid_intent.upload_media(data=resp_data,
  101. mime_type=content_type,
  102. filename=info.profile_pic_id)
  103. else:
  104. mxc = None
  105. try:
  106. await self.default_mxid_intent.set_avatar_url(mxc)
  107. self.avatar_set = True
  108. self.photo_mxc = mxc
  109. except Exception:
  110. self.log.exception("Failed to update avatar")
  111. self.avatar_set = False
  112. return True
  113. return False
  114. async def default_puppet_should_leave_room(self, room_id: RoomID) -> bool:
  115. portal = await p.Portal.get_by_mxid(room_id)
  116. return portal and portal.other_user_pk != self.pk
  117. # region Database getters
  118. def _add_to_cache(self) -> None:
  119. self.by_pk[self.pk] = self
  120. if self.custom_mxid:
  121. self.by_custom_mxid[self.custom_mxid] = self
  122. async def save(self) -> None:
  123. await self.update()
  124. @classmethod
  125. async def get_by_mxid(cls, mxid: UserID, create: bool = True) -> Optional['Puppet']:
  126. pk = cls.get_id_from_mxid(mxid)
  127. if pk:
  128. return await cls.get_by_pk(pk, create)
  129. return None
  130. @classmethod
  131. async def get_by_custom_mxid(cls, mxid: UserID) -> Optional['Puppet']:
  132. try:
  133. return cls.by_custom_mxid[mxid]
  134. except KeyError:
  135. pass
  136. puppet = cast(cls, await super().get_by_custom_mxid(mxid))
  137. if puppet:
  138. puppet._add_to_cache()
  139. return puppet
  140. return None
  141. @classmethod
  142. def get_id_from_mxid(cls, mxid: UserID) -> Optional[int]:
  143. return cls.mxid_template.parse(mxid)
  144. @classmethod
  145. def get_mxid_from_id(cls, twid: int) -> UserID:
  146. return UserID(cls.mxid_template.format_full(twid))
  147. @classmethod
  148. async def get_by_pk(cls, pk: int, create: bool = True) -> Optional['Puppet']:
  149. try:
  150. return cls.by_pk[pk]
  151. except KeyError:
  152. pass
  153. puppet = cast(cls, await super().get_by_pk(pk))
  154. if puppet is not None:
  155. puppet._add_to_cache()
  156. return puppet
  157. if create:
  158. puppet = cls(pk)
  159. await puppet.insert()
  160. puppet._add_to_cache()
  161. return puppet
  162. return None
  163. @classmethod
  164. async def all_with_custom_mxid(cls) -> AsyncGenerator['Puppet', None]:
  165. puppets = await super().all_with_custom_mxid()
  166. puppet: cls
  167. for index, puppet in enumerate(puppets):
  168. try:
  169. yield cls.by_pk[puppet.pk]
  170. except KeyError:
  171. puppet._add_to_cache()
  172. yield puppet
  173. # endregion