Tulir Asokan 3 жил өмнө
parent
commit
7c0cf0513a
6 өөрчлөгдсөн 50 нэмэгдсэн , 51 устгасан
  1. 2 1
      config/config.go
  2. 3 1
      config/upgrade.go
  3. 3 5
      example-config.yaml
  4. 6 0
      main.go
  5. 10 14
      provisioning.go
  6. 26 30
      segment.go

+ 2 - 1
config/config.go

@@ -47,7 +47,6 @@ type Config struct {
 		Provisioning struct {
 			Prefix       string `yaml:"prefix"`
 			SharedSecret string `yaml:"shared_secret"`
-			SegmentKey   string `yaml:"segment_key"`
 		} `yaml:"provisioning"`
 
 		ID  string `yaml:"id"`
@@ -65,6 +64,8 @@ type Config struct {
 		HSToken string `yaml:"hs_token"`
 	} `yaml:"appservice"`
 
+	SegmentKey string `yaml:"segment_key"`
+
 	Metrics struct {
 		Enabled bool   `yaml:"enabled"`
 		Listen  string `yaml:"listen"`

+ 3 - 1
config/upgrade.go

@@ -55,7 +55,6 @@ func (helper *UpgradeHelper) doUpgrade() {
 	} else {
 		helper.Copy(Str, "appservice", "provisioning", "shared_secret")
 	}
-	helper.Copy(Str|Null, "appservice", "provisioning", "segment_key")
 	helper.Copy(Str, "appservice", "id")
 	helper.Copy(Str, "appservice", "bot", "username")
 	helper.Copy(Str, "appservice", "bot", "displayname")
@@ -64,6 +63,8 @@ func (helper *UpgradeHelper) doUpgrade() {
 	helper.Copy(Str, "appservice", "as_token")
 	helper.Copy(Str, "appservice", "hs_token")
 
+	helper.Copy(Str|Null, "segment_key")
+
 	helper.Copy(Bool, "metrics", "enabled")
 	helper.Copy(Str, "metrics", "listen")
 
@@ -170,6 +171,7 @@ func (helper *UpgradeHelper) addSpaces() {
 	helper.addSpaceBeforeComment("appservice", "provisioning")
 	helper.addSpaceBeforeComment("appservice", "id")
 	helper.addSpaceBeforeComment("appservice", "as_token")
+	helper.addSpaceBeforeComment("segment_key")
 	helper.addSpaceBeforeComment("metrics")
 	helper.addSpaceBeforeComment("whatsapp")
 	helper.addSpaceBeforeComment("bridge")

+ 3 - 5
example-config.yaml

@@ -50,11 +50,6 @@ appservice:
         # Shared secret for authentication. If set to "generate", a random secret will be generated,
         # or if set to "disable", the provisioning API will be disabled.
         shared_secret: generate
-        # Segment API key to enable analytics tracking for web server
-        # endpoints. Set to null to disable.
-        # Currently the only events are login start, QR code retrieve, and login
-        # success/failure.
-        segment_key: null
 
     # The unique ID of this appservice.
     id: whatsapp
@@ -76,6 +71,9 @@ appservice:
     as_token: "This value is generated when generating the registration"
     hs_token: "This value is generated when generating the registration"
 
+# Segment API key to track some events, like provisioning API login and encryption errors.
+segment_key: null
+
 # Prometheus config.
 metrics:
     # Enable prometheus metrics?

+ 6 - 0
main.go

@@ -272,6 +272,12 @@ func (bridge *Bridge) Init() {
 	bridge.StateStore = database.NewSQLStateStore(bridge.DB)
 	bridge.AS.StateStore = bridge.StateStore
 
+	Segment.log = bridge.Log.Sub("Segment")
+	Segment.key = bridge.Config.SegmentKey
+	if Segment.IsEnabled() {
+		Segment.log.Infoln("Segment metrics are enabled")
+	}
+
 	bridge.WAContainer = sqlstore.NewWithDB(bridge.DB.DB, bridge.Config.AppService.Database.Type, nil)
 	bridge.WAContainer.DatabaseErrorHandler = bridge.DB.HandleSignalStoreError
 

+ 10 - 14
provisioning.go

@@ -41,17 +41,13 @@ import (
 )
 
 type ProvisioningAPI struct {
-	bridge  *Bridge
-	log     log.Logger
-	segment *Segment
+	bridge *Bridge
+	log    log.Logger
 }
 
 func (prov *ProvisioningAPI) Init() {
 	prov.log = prov.bridge.Log.Sub("Provisioning")
 
-	// Set up segment
-	prov.segment = NewSegment(prov.bridge.Config.AppService.Provisioning.SegmentKey, prov.log)
-
 	prov.log.Debugln("Enabling provisioning API at", prov.bridge.Config.AppService.Provisioning.Prefix)
 	r := prov.bridge.AS.Router.PathPrefix(prov.bridge.Config.AppService.Provisioning.Prefix).Subrouter()
 	r.Use(prov.AuthMiddleware)
@@ -573,7 +569,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 		}
 	}
 	user.log.Debugln("Started login via provisioning API")
-	prov.segment.Track(user.MXID, "$login_start")
+	Segment.Track(user.MXID, "$login_start")
 
 	for {
 		select {
@@ -582,7 +578,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 			case whatsmeow.QRChannelSuccess.Event:
 				jid := user.Client.Store.ID
 				user.log.Debugln("Successful login as", jid, "via provisioning API")
-				prov.segment.Track(user.MXID, "$login_success")
+				Segment.Track(user.MXID, "$login_success")
 				_ = c.WriteJSON(map[string]interface{}{
 					"success":  true,
 					"jid":      jid,
@@ -597,7 +593,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 			case whatsmeow.QRChannelTimeout.Event:
 				user.log.Debugln("Login via provisioning API timed out")
 				errCode := "login timed out"
-				prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
+				Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
 				_ = c.WriteJSON(Error{
 					Error:   "QR code scan timed out. Please try again.",
 					ErrCode: errCode,
@@ -605,7 +601,7 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 			case whatsmeow.QRChannelErrUnexpectedEvent.Event:
 				user.log.Debugln("Login via provisioning API failed due to unexpected event")
 				errCode := "unexpected event"
-				prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
+				Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
 				_ = c.WriteJSON(Error{
 					Error:   "Got unexpected event while waiting for QRs, perhaps you're already logged in?",
 					ErrCode: errCode,
@@ -613,14 +609,14 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 			case whatsmeow.QRChannelClientOutdated.Event:
 				user.log.Debugln("Login via provisioning API failed due to outdated client")
 				errCode := "bridge outdated"
-				prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
+				Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
 				_ = c.WriteJSON(Error{
 					Error:   "Got client outdated error while waiting for QRs. The bridge must be updated to continue.",
 					ErrCode: errCode,
 				})
 			case whatsmeow.QRChannelScannedWithoutMultidevice.Event:
 				errCode := "multidevice not enabled"
-				prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
+				Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
 				_ = c.WriteJSON(Error{
 					Error:   "Please enable the WhatsApp multidevice beta and scan the QR code again.",
 					ErrCode: errCode,
@@ -628,13 +624,13 @@ func (prov *ProvisioningAPI) Login(w http.ResponseWriter, r *http.Request) {
 				continue
 			case "error":
 				errCode := "fatal error"
-				prov.segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
+				Segment.Track(user.MXID, "$login_failure", map[string]interface{}{"error": errCode})
 				_ = c.WriteJSON(Error{
 					Error:   "Fatal error while logging in",
 					ErrCode: errCode,
 				})
 			case "code":
-				prov.segment.Track(user.MXID, "$qrcode_retrieved")
+				Segment.Track(user.MXID, "$qrcode_retrieved")
 				_ = c.WriteJSON(map[string]interface{}{
 					"code":    evt.Code,
 					"timeout": int(evt.Timeout.Seconds()),

+ 26 - 30
segment.go

@@ -27,37 +27,31 @@ import (
 
 const SegmentURL = "https://api.segment.io/v1/track"
 
-type Segment struct {
-	segmentKey string
-	log        log.Logger
-	client     *http.Client
+type SegmentClient struct {
+	key    string
+	log    log.Logger
+	client http.Client
 }
 
-func NewSegment(segmentKey string, parentLogger log.Logger) *Segment {
-	return &Segment{
-		segmentKey: segmentKey,
-		log:        parentLogger.Sub("Segment"),
-		client:     &http.Client{},
-	}
-}
+var Segment SegmentClient
 
-func (segment *Segment) track(userID id.UserID, event string, properties map[string]interface{}) error {
-	data := map[string]interface{}{
+func (sc *SegmentClient) trackSync(userID id.UserID, event string, properties map[string]interface{}) error {
+	var buf bytes.Buffer
+	err := json.NewEncoder(&buf).Encode(map[string]interface{}{
 		"userId":     userID,
 		"event":      event,
 		"properties": properties,
-	}
-	json_data, err := json.Marshal(data)
+	})
 	if err != nil {
 		return err
 	}
 
-	req, err := http.NewRequest("POST", SegmentURL, bytes.NewBuffer(json_data))
+	req, err := http.NewRequest("POST", SegmentURL, &buf)
 	if err != nil {
 		return err
 	}
-	req.SetBasicAuth(segment.segmentKey, "")
-	resp, err := segment.client.Do(req)
+	req.SetBasicAuth(sc.key, "")
+	resp, err := sc.client.Do(req)
 	if err != nil {
 		return err
 	}
@@ -65,26 +59,28 @@ func (segment *Segment) track(userID id.UserID, event string, properties map[str
 	return nil
 }
 
-func (segment *Segment) Track(userID id.UserID, event string, properties ...map[string]interface{}) {
-	if segment.segmentKey == "" {
+func (sc *SegmentClient) IsEnabled() bool {
+	return len(sc.key) > 0
+}
+
+func (sc *SegmentClient) Track(userID id.UserID, event string, properties ...map[string]interface{}) {
+	if !sc.IsEnabled() {
 		return
-	}
-	if len(properties) > 1 {
-		segment.log.Fatalf("Track should be called with at most one property map")
+	} else if len(properties) > 1 {
+		panic("Track should be called with at most one property map")
 	}
 
-	go (func() error {
+	go func() {
 		props := map[string]interface{}{}
 		if len(properties) > 0 {
 			props = properties[0]
 		}
 		props["bridge"] = "whatsapp"
-		err := segment.track(userID, event, props)
+		err := sc.trackSync(userID, event, props)
 		if err != nil {
-			segment.log.Errorf("Error tracking %s: %v+", event, err)
-			return err
+			sc.log.Errorfln("Error tracking %s: %v", event, err)
+		} else {
+			sc.log.Debugln("Tracked", event)
 		}
-		segment.log.Debug("Tracked ", event)
-		return nil
-	})()
+	}()
 }