Ver Fonte

Fix NeedsRelaybot check for groups too

Tulir Asokan há 5 anos atrás
pai
commit
3caca1b9a0
3 ficheiros alterados com 7 adições e 7 exclusões
  1. 5 5
      database/user.go
  2. 1 1
      portal.go
  3. 1 1
      user.go

+ 5 - 5
database/user.go

@@ -201,11 +201,11 @@ func (user *User) SetPortalKeys(newKeys []PortalKeyWithMeta) error {
 	return tx.Commit()
 	return tx.Commit()
 }
 }
 
 
-func (user *User) IsInPortal(jid types.WhatsAppID) bool {
-	row := user.db.QueryRow(`SELECT portal_jid, portal_receiver FROM user_portal WHERE user_jid=$1 AND portal_jid=$2 AND (portal_receiver=$1 OR portal_receiver=$2)`, user.jidPtr(), &jid)
-	var scanJid, scanReceiver types.WhatsAppID
-	_ = row.Scan(&scanJid, &scanReceiver)
-	return scanJid == jid && (scanReceiver == jid || scanReceiver == user.JID)
+func (user *User) IsInPortal(key PortalKey) bool {
+	row := user.db.QueryRow(`SELECT EXISTS(SELECT 1 FROM user_portal WHERE user_jid=$1 AND portal_jid=$2 AND portal_receiver=$3)`, user.jidPtr(), &key.JID, &key.Receiver)
+	var exists bool
+	_ = row.Scan(&exists)
+	return exists
 }
 }
 
 
 func (user *User) GetPortalKeys() []PortalKey {
 func (user *User) GetPortalKeys() []PortalKey {

+ 1 - 1
portal.go

@@ -810,7 +810,7 @@ func (portal *Portal) HasRelaybot() bool {
 	if portal.bridge.Relaybot == nil {
 	if portal.bridge.Relaybot == nil {
 		return false
 		return false
 	} else if portal.hasRelaybot == nil {
 	} else if portal.hasRelaybot == nil {
-		val := portal.bridge.Relaybot.IsInPortal(portal.Key.JID)
+		val := portal.bridge.Relaybot.IsInPortal(portal.Key)
 		portal.hasRelaybot = &val
 		portal.hasRelaybot = &val
 	}
 	}
 	return *portal.hasRelaybot
 	return *portal.hasRelaybot

+ 1 - 1
user.go

@@ -781,5 +781,5 @@ func (user *User) HandleRawMessage(message *waProto.WebMessageInfo) {
 }
 }
 
 
 func (user *User) NeedsRelaybot(portal *Portal) bool {
 func (user *User) NeedsRelaybot(portal *Portal) bool {
-	return !user.HasSession() || (user.IsInPortal(portal.Key.JID) && (!portal.IsPrivateChat() || portal.Key.Receiver == user.JID))
+	return !user.HasSession() || !user.IsInPortal(portal.Key)
 }
 }