puppet.go 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311
  1. package bridge
  2. import (
  3. "fmt"
  4. "io"
  5. "net/http"
  6. "regexp"
  7. "sync"
  8. log "maunium.net/go/maulogger/v2"
  9. "maunium.net/go/mautrix/appservice"
  10. "maunium.net/go/mautrix/id"
  11. "gitlab.com/beeper/discord/database"
  12. )
// Puppet pairs a stored database.Puppet row with the runtime state the bridge
// keeps for it.
type Puppet struct {
	// Embedded database row; its fields (ID, DisplayName, Avatar, ...) and
	// methods such as Update are used directly on the Puppet.
	*database.Puppet

	// bridge is the owning Bridge, set by NewPuppet.
	bridge *Bridge

	// log is a sub-logger named "Puppet/<ID>", created in NewPuppet.
	log log.Logger

	// MXID is the Matrix user ID derived from the puppet ID via
	// Bridge.FormatPuppetMXID.
	MXID id.UserID

	// customIntent, when non-nil, is returned by IntentFor instead of the
	// default appservice intent.
	customIntent *appservice.IntentAPI

	// customUser is not referenced in this part of the file.
	// NOTE(review): presumably the User behind customIntent — confirm.
	customUser *User

	// syncLock is held for the duration of SyncContact.
	syncLock sync.Mutex
}
// userIDRegex holds the compiled pattern ParsePuppetMXID uses to recognise
// puppet MXIDs. It is nil until ParsePuppetMXID compiles it on first use.
var userIDRegex *regexp.Regexp
  23. func (b *Bridge) NewPuppet(dbPuppet *database.Puppet) *Puppet {
  24. return &Puppet{
  25. Puppet: dbPuppet,
  26. bridge: b,
  27. log: b.log.Sub(fmt.Sprintf("Puppet/%s", dbPuppet.ID)),
  28. MXID: b.FormatPuppetMXID(dbPuppet.ID),
  29. }
  30. }
  31. func (b *Bridge) ParsePuppetMXID(mxid id.UserID) (string, bool) {
  32. if userIDRegex == nil {
  33. pattern := fmt.Sprintf(
  34. "^@%s:%s$",
  35. b.Config.Bridge.FormatUsername("([0-9]+)"),
  36. b.Config.Homeserver.Domain,
  37. )
  38. userIDRegex = regexp.MustCompile(pattern)
  39. }
  40. match := userIDRegex.FindStringSubmatch(string(mxid))
  41. if len(match) == 2 {
  42. return match[1], true
  43. }
  44. return "", false
  45. }
  46. func (b *Bridge) GetPuppetByMXID(mxid id.UserID) *Puppet {
  47. id, ok := b.ParsePuppetMXID(mxid)
  48. if !ok {
  49. return nil
  50. }
  51. return b.GetPuppetByID(id)
  52. }
  53. func (b *Bridge) GetPuppetByID(id string) *Puppet {
  54. b.puppetsLock.Lock()
  55. defer b.puppetsLock.Unlock()
  56. puppet, ok := b.puppets[id]
  57. if !ok {
  58. dbPuppet := b.db.Puppet.Get(id)
  59. if dbPuppet == nil {
  60. dbPuppet = b.db.Puppet.New()
  61. dbPuppet.ID = id
  62. dbPuppet.Insert()
  63. }
  64. puppet = b.NewPuppet(dbPuppet)
  65. b.puppets[puppet.ID] = puppet
  66. }
  67. return puppet
  68. }
  69. func (b *Bridge) GetPuppetByCustomMXID(mxid id.UserID) *Puppet {
  70. b.puppetsLock.Lock()
  71. defer b.puppetsLock.Unlock()
  72. puppet, ok := b.puppetsByCustomMXID[mxid]
  73. if !ok {
  74. dbPuppet := b.db.Puppet.GetByCustomMXID(mxid)
  75. if dbPuppet == nil {
  76. return nil
  77. }
  78. puppet = b.NewPuppet(dbPuppet)
  79. b.puppets[puppet.ID] = puppet
  80. b.puppetsByCustomMXID[puppet.CustomMXID] = puppet
  81. }
  82. return puppet
  83. }
  84. func (b *Bridge) GetAllPuppetsWithCustomMXID() []*Puppet {
  85. return b.dbPuppetsToPuppets(b.db.Puppet.GetAllWithCustomMXID())
  86. }
  87. func (b *Bridge) GetAllPuppets() []*Puppet {
  88. return b.dbPuppetsToPuppets(b.db.Puppet.GetAll())
  89. }
  90. func (b *Bridge) dbPuppetsToPuppets(dbPuppets []*database.Puppet) []*Puppet {
  91. b.puppetsLock.Lock()
  92. defer b.puppetsLock.Unlock()
  93. output := make([]*Puppet, len(dbPuppets))
  94. for index, dbPuppet := range dbPuppets {
  95. if dbPuppet == nil {
  96. continue
  97. }
  98. puppet, ok := b.puppets[dbPuppet.ID]
  99. if !ok {
  100. puppet = b.NewPuppet(dbPuppet)
  101. b.puppets[dbPuppet.ID] = puppet
  102. if dbPuppet.CustomMXID != "" {
  103. b.puppetsByCustomMXID[dbPuppet.CustomMXID] = puppet
  104. }
  105. }
  106. output[index] = puppet
  107. }
  108. return output
  109. }
  110. func (b *Bridge) FormatPuppetMXID(did string) id.UserID {
  111. return id.NewUserID(
  112. b.Config.Bridge.FormatUsername(did),
  113. b.Config.Homeserver.Domain,
  114. )
  115. }
  116. func (p *Puppet) DefaultIntent() *appservice.IntentAPI {
  117. return p.bridge.as.Intent(p.MXID)
  118. }
  119. func (p *Puppet) IntentFor(portal *Portal) *appservice.IntentAPI {
  120. if p.customIntent == nil {
  121. return p.DefaultIntent()
  122. }
  123. return p.customIntent
  124. }
  125. func (p *Puppet) CustomIntent() *appservice.IntentAPI {
  126. return p.customIntent
  127. }
  128. func (p *Puppet) updatePortalMeta(meta func(portal *Portal)) {
  129. for _, portal := range p.bridge.GetAllPortalsByID(p.ID) {
  130. meta(portal)
  131. }
  132. }
  133. func (p *Puppet) updateName(source *User) bool {
  134. user, err := source.Session.User(p.ID)
  135. if err != nil {
  136. p.log.Warnln("failed to get user from id:", err)
  137. return false
  138. }
  139. newName := p.bridge.Config.Bridge.FormatDisplayname(user)
  140. if p.DisplayName != newName {
  141. err := p.DefaultIntent().SetDisplayName(newName)
  142. if err == nil {
  143. p.DisplayName = newName
  144. go p.updatePortalName()
  145. p.Update()
  146. } else {
  147. p.log.Warnln("failed to set display name:", err)
  148. }
  149. return true
  150. }
  151. return false
  152. }
  153. func (p *Puppet) updatePortalName() {
  154. p.updatePortalMeta(func(portal *Portal) {
  155. if portal.MXID != "" {
  156. _, err := portal.MainIntent().SetRoomName(portal.MXID, p.DisplayName)
  157. if err != nil {
  158. portal.log.Warnln("Failed to set name:", err)
  159. }
  160. }
  161. portal.Name = p.DisplayName
  162. portal.Update()
  163. })
  164. }
  165. func (p *Puppet) uploadAvatar(intent *appservice.IntentAPI, url string) (id.ContentURI, error) {
  166. getResp, err := http.DefaultClient.Get(url)
  167. if err != nil {
  168. return id.ContentURI{}, fmt.Errorf("failed to download avatar: %w", err)
  169. }
  170. data, err := io.ReadAll(getResp.Body)
  171. getResp.Body.Close()
  172. if err != nil {
  173. return id.ContentURI{}, fmt.Errorf("failed to read avatar data: %w", err)
  174. }
  175. mime := http.DetectContentType(data)
  176. resp, err := intent.UploadBytes(data, mime)
  177. if err != nil {
  178. return id.ContentURI{}, fmt.Errorf("failed to upload avatar to Matrix: %w", err)
  179. }
  180. return resp.ContentURI, nil
  181. }
  182. func (p *Puppet) updateAvatar(source *User) bool {
  183. user, err := source.Session.User(p.ID)
  184. if err != nil {
  185. p.log.Warnln("Failed to get user:", err)
  186. return false
  187. }
  188. if p.Avatar == user.Avatar {
  189. return false
  190. }
  191. if user.Avatar == "" {
  192. p.log.Warnln("User does not have an avatar")
  193. return false
  194. }
  195. url, err := p.uploadAvatar(p.DefaultIntent(), user.AvatarURL(""))
  196. if err != nil {
  197. p.log.Warnln("Failed to upload user avatar:", err)
  198. return false
  199. }
  200. p.AvatarURL = url
  201. err = p.DefaultIntent().SetAvatarURL(p.AvatarURL)
  202. if err != nil {
  203. p.log.Warnln("Failed to set avatar:", err)
  204. }
  205. p.log.Debugln("Updated avatar", p.Avatar, "->", user.Avatar)
  206. p.Avatar = user.Avatar
  207. go p.updatePortalAvatar()
  208. return true
  209. }
  210. func (p *Puppet) updatePortalAvatar() {
  211. p.updatePortalMeta(func(portal *Portal) {
  212. if portal.MXID != "" {
  213. _, err := portal.MainIntent().SetRoomAvatar(portal.MXID, p.AvatarURL)
  214. if err != nil {
  215. portal.log.Warnln("Failed to set avatar:", err)
  216. }
  217. }
  218. portal.AvatarURL = p.AvatarURL
  219. portal.Avatar = p.Avatar
  220. portal.Update()
  221. })
  222. }
  223. func (p *Puppet) SyncContact(source *User) {
  224. p.syncLock.Lock()
  225. defer p.syncLock.Unlock()
  226. p.log.Debugln("syncing contact", p.DisplayName)
  227. err := p.DefaultIntent().EnsureRegistered()
  228. if err != nil {
  229. p.log.Errorln("Failed to ensure registered:", err)
  230. }
  231. update := false
  232. update = p.updateName(source) || update
  233. if p.Avatar == "" {
  234. update = p.updateAvatar(source) || update
  235. p.log.Debugln("update avatar returned", update)
  236. }
  237. if update {
  238. p.Update()
  239. }
  240. }