diff --git a/auth/host_session.go b/auth/host_session.go index 0bc7000a..7b74877a 100644 --- a/auth/host_session.go +++ b/auth/host_session.go @@ -132,11 +132,11 @@ func SessionHandler(conn *websocket.Conn) { hostPass := result.Host.HostPass if !logic.HostExists(&result.Host) { // check if host already exists, add if not if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { - if err := mq.CreateEmqxUser(result.Host.ID.String(), result.Host.HostPass, false); err != nil { + if err := mq.GetEmqxHandler().CreateEmqxUser(result.Host.ID.String(), result.Host.HostPass, false); err != nil { logger.Log(0, "failed to create host credentials for EMQX: ", err.Error()) return } - if err := mq.CreateHostACL(result.Host.ID.String(), servercfg.GetServerInfo().Server); err != nil { + if err := mq.GetEmqxHandler().CreateHostACL(result.Host.ID.String(), servercfg.GetServerInfo().Server); err != nil { logger.Log(0, "failed to add host ACL rules to EMQX: ", err.Error()) return } diff --git a/controllers/enrollmentkeys.go b/controllers/enrollmentkeys.go index b75d7b04..075db8fb 100644 --- a/controllers/enrollmentkeys.go +++ b/controllers/enrollmentkeys.go @@ -312,11 +312,11 @@ func handleHostRegister(w http.ResponseWriter, r *http.Request) { logic.CheckHostPorts(&newHost) // create EMQX credentials and ACLs for host if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { - if err := mq.CreateEmqxUser(newHost.ID.String(), newHost.HostPass, false); err != nil { + if err := mq.GetEmqxHandler().CreateEmqxUser(newHost.ID.String(), newHost.HostPass, false); err != nil { logger.Log(0, "failed to create host credentials for EMQX: ", err.Error()) return } - if err := mq.CreateHostACL(newHost.ID.String(), servercfg.GetServerInfo().Server); err != nil { + if err := mq.GetEmqxHandler().CreateHostACL(newHost.ID.String(), servercfg.GetServerInfo().Server); err != nil { logger.Log(0, "failed to add host ACL rules to EMQX: ", err.Error()) return } diff --git a/controllers/hosts.go b/controllers/hosts.go index 6d10a6f1..08d8a200 100644 --- a/controllers/hosts.go +++ b/controllers/hosts.go @@ -298,7 +298,7 @@ func deleteHost(w http.ResponseWriter, r *http.Request) { } if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { // delete EMQX credentials for host - if err := mq.DeleteEmqxUser(currHost.ID.String()); err != nil { + if err := mq.GetEmqxHandler().DeleteEmqxUser(currHost.ID.String()); err != nil { slog.Error("failed to remove host credentials from EMQX", "id", currHost.ID, "error", err) } } @@ -549,15 +549,15 @@ func authenticateHost(response http.ResponseWriter, request *http.Request) { // Create EMQX creds and ACLs if not found if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { - if err := mq.CreateEmqxUser(host.ID.String(), authRequest.Password, false); err != nil { + if err := mq.GetEmqxHandler().CreateEmqxUser(host.ID.String(), authRequest.Password, false); err != nil { slog.Error("failed to create host credentials for EMQX: ", err.Error()) } else { - if err := mq.CreateHostACL(host.ID.String(), servercfg.GetServerInfo().Server); err != nil { + if err := mq.GetEmqxHandler().CreateHostACL(host.ID.String(), servercfg.GetServerInfo().Server); err != nil { slog.Error("failed to add host ACL rules to EMQX: ", err.Error()) } for _, nodeID := range host.Nodes { if node, err := logic.GetNodeByID(nodeID); err == nil { - if err = mq.AppendNodeUpdateACL(host.ID.String(), node.Network, node.ID.String(), servercfg.GetServer()); err != nil { + if err = mq.GetEmqxHandler().AppendNodeUpdateACL(host.ID.String(), node.Network, node.ID.String(), servercfg.GetServer()); err != nil { slog.Error("failed to add ACLs for EMQX node", "error", err) } } else { diff --git a/mq/emqx.go b/mq/emqx.go index 9727fae2..3e29f483 100644 --- a/mq/emqx.go +++ b/mq/emqx.go @@ -1,386 +1,40 @@ package mq -import ( - "bytes" - "encoding/json" - "fmt" - "io" - "net/http" - "strings" - "sync" +import "github.com/gravitl/netmaker/servercfg" - "github.com/gravitl/netmaker/servercfg" -) +var emqx Emqx -const already_exists = "ALREADY_EXISTS" - -type ( - emqxUser struct { - UserID string `json:"user_id"` - Password string `json:"password"` - Admin bool `json:"is_superuser"` - } - - emqxLogin struct { - Username string `json:"username"` - Password string `json:"password"` - } - - emqxLoginResponse struct { - License struct { - Edition string `json:"edition"` - } `json:"license"` - Token string `json:"token"` - Version string `json:"version"` - } - - aclRule struct { - Topic string `json:"topic"` - Permission string `json:"permission"` - Action string `json:"action"` - } - - aclObject struct { - Rules []aclRule `json:"rules"` - Username string `json:"username,omitempty"` - } -) - -func getEmqxAuthToken() (string, error) { - payload, err := json.Marshal(&emqxLogin{ - Username: servercfg.GetMqUserName(), - Password: servercfg.GetMqPassword(), - }) - if err != nil { - return "", err - } - resp, err := http.Post(servercfg.GetEmqxRestEndpoint()+"/api/v5/login", "application/json", bytes.NewReader(payload)) - if err != nil { - return "", err - } - msg, err := io.ReadAll(resp.Body) - if err != nil { - return "", err - } - if resp.StatusCode != http.StatusOK { - return "", fmt.Errorf("error during EMQX login %v", string(msg)) - } - var loginResp emqxLoginResponse - if err := json.Unmarshal(msg, &loginResp); err != nil { - return "", err - } - return loginResp.Token, nil +type Emqx interface { + GetType() servercfg.Emqxdeploy + CreateEmqxUser(username, password string, admin bool) error + CreateEmqxDefaultAuthenticator() error + CreateEmqxDefaultAuthorizer() error + CreateDefaultDenyRule() error + CreateHostACL(hostID, serverName string) error + AppendNodeUpdateACL(hostID, nodeNetwork, nodeID, serverName string) error + GetUserACL(username string) (*aclObject, error) + DeleteEmqxUser(username string) error } -// CreateEmqxUser - creates an EMQX user -func CreateEmqxUser(username, password string, admin bool) error { - token, err := getEmqxAuthToken() - if err != nil { - return err +func init() { + if servercfg.GetBrokerType() != servercfg.EmqxBrokerType { + return } - payload, err := json.Marshal(&emqxUser{ - UserID: username, - Password: password, - Admin: admin, - }) - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authentication/password_based:built_in_database/users", bytes.NewReader(payload)) - if err != nil { - return err - } - req.Header.Add("content-type", "application/json") - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode >= 300 { - msg, err := io.ReadAll(resp.Body) - if err != nil { - return err + if servercfg.GetEmqxDeployType() == servercfg.EmqxCloudDeploy { + emqx = &EmqxCloud{ + URL: servercfg.GetEmqxRestEndpoint(), + AppID: servercfg.GetMqUserName(), + AppSecret: servercfg.GetMqPassword(), } - if !strings.Contains(string(msg), already_exists) { - return fmt.Errorf("error creating EMQX user %v", string(msg)) + } else { + emqx = &EmqxOnPrem{ + URL: servercfg.GetEmqxRestEndpoint(), + UserName: servercfg.GetMqUserName(), + Password: servercfg.GetMqPassword(), } } - return nil } -// DeleteEmqxUser - deletes an EMQX user -func DeleteEmqxUser(username string) error { - token, err := getEmqxAuthToken() - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodDelete, servercfg.GetEmqxRestEndpoint()+"/api/v5/authentication/password_based:built_in_database/users/"+username, nil) - if err != nil { - return err - } - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode >= 300 { - msg, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - return fmt.Errorf("error deleting EMQX user %v", string(msg)) - } - return nil -} - -// CreateEmqxDefaultAuthenticator - creates a default authenticator based on password and using EMQX's built in database as storage -func CreateEmqxDefaultAuthenticator() error { - token, err := getEmqxAuthToken() - if err != nil { - return err - } - payload, err := json.Marshal(&struct { - Mechanism string `json:"mechanism"` - Backend string `json:"backend"` - UserIDType string `json:"user_id_type"` - }{Mechanism: "password_based", Backend: "built_in_database", UserIDType: "username"}) - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authentication", bytes.NewReader(payload)) - if err != nil { - return err - } - req.Header.Add("content-type", "application/json") - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusOK { - msg, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - return fmt.Errorf("error creating default EMQX authenticator %v", string(msg)) - } - return nil -} - -// CreateEmqxDefaultAuthorizer - creates a default ACL authorization mechanism based on the built in database -func CreateEmqxDefaultAuthorizer() error { - token, err := getEmqxAuthToken() - if err != nil { - return err - } - payload, err := json.Marshal(&struct { - Enable bool `json:"enable"` - Type string `json:"type"` - }{Enable: true, Type: "built_in_database"}) - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources", bytes.NewReader(payload)) - if err != nil { - return err - } - req.Header.Add("content-type", "application/json") - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusNoContent { - msg, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg)) - } - return nil -} - -// GetUserACL - returns ACL rules by username -func GetUserACL(username string) (*aclObject, error) { - token, err := getEmqxAuthToken() - if err != nil { - return nil, err - } - req, err := http.NewRequest(http.MethodGet, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/username/"+username, nil) - if err != nil { - return nil, err - } - req.Header.Add("content-type", "application/json") - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return nil, err - } - defer resp.Body.Close() - response, err := io.ReadAll(resp.Body) - if err != nil { - return nil, err - } - if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("error fetching ACL rules %v", string(response)) - } - body := new(aclObject) - if err := json.Unmarshal(response, body); err != nil { - return nil, err - } - return body, nil -} - -// CreateDefaultDenyRule - creates a rule to deny access to all topics for all users by default -// to allow user access to topics use the `mq.CreateUserAccessRule` function -func CreateDefaultDenyRule() error { - token, err := getEmqxAuthToken() - if err != nil { - return err - } - payload, err := json.Marshal(&aclObject{Rules: []aclRule{{Topic: "#", Permission: "deny", Action: "all"}}}) - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/all", bytes.NewReader(payload)) - if err != nil { - return err - } - req.Header.Add("content-type", "application/json") - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusNoContent { - msg, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - return fmt.Errorf("error creating default ACL rules %v", string(msg)) - } - return nil -} - -// CreateHostACL - create host ACL rules -func CreateHostACL(hostID, serverName string) error { - token, err := getEmqxAuthToken() - if err != nil { - return err - } - payload, err := json.Marshal(&aclObject{ - Username: hostID, - Rules: []aclRule{ - { - Topic: fmt.Sprintf("peers/host/%s/%s", hostID, serverName), - Permission: "allow", - Action: "all", - }, - { - Topic: fmt.Sprintf("host/update/%s/%s", hostID, serverName), - Permission: "allow", - Action: "all", - }, - { - Topic: fmt.Sprintf("host/serverupdate/%s/%s", serverName, hostID), - Permission: "allow", - Action: "all", - }, - }, - }) - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodPut, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/username/"+hostID, bytes.NewReader(payload)) - if err != nil { - return err - } - req.Header.Add("content-type", "application/json") - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusNoContent { - msg, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - return fmt.Errorf("error adding ACL Rules for user %s Error: %v", hostID, string(msg)) - } - return nil -} - -// a lock required for preventing simultaneous updates to the same ACL object leading to overwriting each other -// might occur when multiple nodes belonging to the same host are created at the same time -var nodeAclMux sync.Mutex - -// AppendNodeUpdateACL - adds ACL rule for subscribing to node updates for a node ID -func AppendNodeUpdateACL(hostID, nodeNetwork, nodeID, serverName string) error { - nodeAclMux.Lock() - defer nodeAclMux.Unlock() - token, err := getEmqxAuthToken() - if err != nil { - return err - } - aclObject, err := GetUserACL(hostID) - if err != nil { - return err - } - aclObject.Rules = append(aclObject.Rules, []aclRule{ - { - Topic: fmt.Sprintf("node/update/%s/%s", nodeNetwork, nodeID), - Permission: "allow", - Action: "subscribe", - }, - { - Topic: fmt.Sprintf("ping/%s/%s", serverName, nodeID), - Permission: "allow", - Action: "all", - }, - { - Topic: fmt.Sprintf("update/%s/%s", serverName, nodeID), - Permission: "allow", - Action: "all", - }, - { - Topic: fmt.Sprintf("signal/%s/%s", serverName, nodeID), - Permission: "allow", - Action: "all", - }, - { - Topic: fmt.Sprintf("metrics/%s/%s", serverName, nodeID), - Permission: "allow", - Action: "all", - }, - }...) - payload, err := json.Marshal(aclObject) - if err != nil { - return err - } - req, err := http.NewRequest(http.MethodPut, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/username/"+hostID, bytes.NewReader(payload)) - if err != nil { - return err - } - req.Header.Add("content-type", "application/json") - req.Header.Add("authorization", "Bearer "+token) - resp, err := (&http.Client{}).Do(req) - if err != nil { - return err - } - defer resp.Body.Close() - if resp.StatusCode != http.StatusNoContent { - msg, err := io.ReadAll(resp.Body) - if err != nil { - return err - } - return fmt.Errorf("error adding ACL Rules for user %s Error: %v", hostID, string(msg)) - } - return nil +func GetEmqxHandler() Emqx { + return emqx } diff --git a/mq/emqx_cloud.go b/mq/emqx_cloud.go new file mode 100644 index 00000000..60b0b2b4 --- /dev/null +++ b/mq/emqx_cloud.go @@ -0,0 +1,72 @@ +package mq + +import ( + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gravitl/netmaker/servercfg" +) + +type EmqxCloud struct { + URL string + AppID string + AppSecret string +} + +type userCreateReq struct { + UserName string `json:"username"` + Password string `json:"password"` +} + +func (e *EmqxCloud) GetType() servercfg.Emqxdeploy { return servercfg.EmqxCloudDeploy } + +func (e *EmqxCloud) CreateEmqxUser(username, pass string, admin bool) error { + + payload := userCreateReq{ + UserName: username, + Password: pass, + } + data, _ := json.Marshal(payload) + client := &http.Client{} + req, err := http.NewRequest(http.MethodPost, e.URL, strings.NewReader(string(data))) + if err != nil { + fmt.Println(err) + return err + } + req.SetBasicAuth(e.AppID, e.AppSecret) + req.Header.Add("Content-Type", "application/json") + + res, err := client.Do(req) + if err != nil { + fmt.Println(err) + return err + } + defer res.Body.Close() + + body, err := io.ReadAll(res.Body) + if err != nil { + fmt.Println(err) + return err + } + fmt.Println(string(body)) + return nil +} + +func (e *EmqxCloud) CreateEmqxDefaultAuthenticator() error { return nil } + +func (e *EmqxCloud) CreateEmqxDefaultAuthorizer() error { return nil } + +func (e *EmqxCloud) CreateDefaultDenyRule() error { return nil } + +func (e *EmqxCloud) CreateHostACL(hostID, serverName string) error { return nil } + +func (e *EmqxCloud) AppendNodeUpdateACL(hostID, nodeNetwork, nodeID, serverName string) error { + return nil +} + +func (e *EmqxCloud) GetUserACL(username string) (*aclObject, error) { return nil, nil } + +func (e *EmqxCloud) DeleteEmqxUser(username string) error { return nil } diff --git a/mq/emqx_on_prem.go b/mq/emqx_on_prem.go new file mode 100644 index 00000000..73ef9fff --- /dev/null +++ b/mq/emqx_on_prem.go @@ -0,0 +1,394 @@ +package mq + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + "sync" + + "github.com/gravitl/netmaker/servercfg" +) + +type EmqxOnPrem struct { + URL string + UserName string + Password string +} + +const already_exists = "ALREADY_EXISTS" + +type ( + emqxUser struct { + UserID string `json:"user_id"` + Password string `json:"password"` + Admin bool `json:"is_superuser"` + } + + emqxLogin struct { + Username string `json:"username"` + Password string `json:"password"` + } + + emqxLoginResponse struct { + License struct { + Edition string `json:"edition"` + } `json:"license"` + Token string `json:"token"` + Version string `json:"version"` + } + + aclRule struct { + Topic string `json:"topic"` + Permission string `json:"permission"` + Action string `json:"action"` + } + + aclObject struct { + Rules []aclRule `json:"rules"` + Username string `json:"username,omitempty"` + } +) + +func getEmqxAuthToken() (string, error) { + payload, err := json.Marshal(&emqxLogin{ + Username: servercfg.GetMqUserName(), + Password: servercfg.GetMqPassword(), + }) + if err != nil { + return "", err + } + resp, err := http.Post(servercfg.GetEmqxRestEndpoint()+"/api/v5/login", "application/json", bytes.NewReader(payload)) + if err != nil { + return "", err + } + msg, err := io.ReadAll(resp.Body) + if err != nil { + return "", err + } + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("error during EMQX login %v", string(msg)) + } + var loginResp emqxLoginResponse + if err := json.Unmarshal(msg, &loginResp); err != nil { + return "", err + } + return loginResp.Token, nil +} + +func (e *EmqxOnPrem) GetType() servercfg.Emqxdeploy { return servercfg.EmqxOnPremDeploy } + +// CreateEmqxUser - creates an EMQX user +func (e *EmqxOnPrem) CreateEmqxUser(username, password string, admin bool) error { + token, err := getEmqxAuthToken() + if err != nil { + return err + } + payload, err := json.Marshal(&emqxUser{ + UserID: username, + Password: password, + Admin: admin, + }) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authentication/password_based:built_in_database/users", bytes.NewReader(payload)) + if err != nil { + return err + } + req.Header.Add("content-type", "application/json") + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + msg, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + if !strings.Contains(string(msg), already_exists) { + return fmt.Errorf("error creating EMQX user %v", string(msg)) + } + } + return nil +} + +// DeleteEmqxUser - deletes an EMQX user +func (e *EmqxOnPrem) DeleteEmqxUser(username string) error { + token, err := getEmqxAuthToken() + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodDelete, servercfg.GetEmqxRestEndpoint()+"/api/v5/authentication/password_based:built_in_database/users/"+username, nil) + if err != nil { + return err + } + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode >= 300 { + msg, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("error deleting EMQX user %v", string(msg)) + } + return nil +} + +// CreateEmqxDefaultAuthenticator - creates a default authenticator based on password and using EMQX's built in database as storage +func (e *EmqxOnPrem) CreateEmqxDefaultAuthenticator() error { + token, err := getEmqxAuthToken() + if err != nil { + return err + } + payload, err := json.Marshal(&struct { + Mechanism string `json:"mechanism"` + Backend string `json:"backend"` + UserIDType string `json:"user_id_type"` + }{Mechanism: "password_based", Backend: "built_in_database", UserIDType: "username"}) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authentication", bytes.NewReader(payload)) + if err != nil { + return err + } + req.Header.Add("content-type", "application/json") + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { + msg, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("error creating default EMQX authenticator %v", string(msg)) + } + return nil +} + +// CreateEmqxDefaultAuthorizer - creates a default ACL authorization mechanism based on the built in database +func (e *EmqxOnPrem) CreateEmqxDefaultAuthorizer() error { + token, err := getEmqxAuthToken() + if err != nil { + return err + } + payload, err := json.Marshal(&struct { + Enable bool `json:"enable"` + Type string `json:"type"` + }{Enable: true, Type: "built_in_database"}) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources", bytes.NewReader(payload)) + if err != nil { + return err + } + req.Header.Add("content-type", "application/json") + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + msg, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("error creating default EMQX ACL authorization mechanism %v", string(msg)) + } + return nil +} + +// GetUserACL - returns ACL rules by username +func (e *EmqxOnPrem) GetUserACL(username string) (*aclObject, error) { + token, err := getEmqxAuthToken() + if err != nil { + return nil, err + } + req, err := http.NewRequest(http.MethodGet, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/username/"+username, nil) + if err != nil { + return nil, err + } + req.Header.Add("content-type", "application/json") + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + response, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("error fetching ACL rules %v", string(response)) + } + body := new(aclObject) + if err := json.Unmarshal(response, body); err != nil { + return nil, err + } + return body, nil +} + +// CreateDefaultDenyRule - creates a rule to deny access to all topics for all users by default +// to allow user access to topics use the `mq.CreateUserAccessRule` function +func (e *EmqxOnPrem) CreateDefaultDenyRule() error { + token, err := getEmqxAuthToken() + if err != nil { + return err + } + payload, err := json.Marshal(&aclObject{Rules: []aclRule{{Topic: "#", Permission: "deny", Action: "all"}}}) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodPost, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/all", bytes.NewReader(payload)) + if err != nil { + return err + } + req.Header.Add("content-type", "application/json") + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + msg, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("error creating default ACL rules %v", string(msg)) + } + return nil +} + +// CreateHostACL - create host ACL rules +func (e *EmqxOnPrem) CreateHostACL(hostID, serverName string) error { + token, err := getEmqxAuthToken() + if err != nil { + return err + } + payload, err := json.Marshal(&aclObject{ + Username: hostID, + Rules: []aclRule{ + { + Topic: fmt.Sprintf("peers/host/%s/%s", hostID, serverName), + Permission: "allow", + Action: "all", + }, + { + Topic: fmt.Sprintf("host/update/%s/%s", hostID, serverName), + Permission: "allow", + Action: "all", + }, + { + Topic: fmt.Sprintf("host/serverupdate/%s/%s", serverName, hostID), + Permission: "allow", + Action: "all", + }, + }, + }) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodPut, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/username/"+hostID, bytes.NewReader(payload)) + if err != nil { + return err + } + req.Header.Add("content-type", "application/json") + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + msg, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("error adding ACL Rules for user %s Error: %v", hostID, string(msg)) + } + return nil +} + +// a lock required for preventing simultaneous updates to the same ACL object leading to overwriting each other +// might occur when multiple nodes belonging to the same host are created at the same time +var nodeAclMux sync.Mutex + +// AppendNodeUpdateACL - adds ACL rule for subscribing to node updates for a node ID +func (e *EmqxOnPrem) AppendNodeUpdateACL(hostID, nodeNetwork, nodeID, serverName string) error { + nodeAclMux.Lock() + defer nodeAclMux.Unlock() + token, err := getEmqxAuthToken() + if err != nil { + return err + } + aclObject, err := emqx.GetUserACL(hostID) + if err != nil { + return err + } + aclObject.Rules = append(aclObject.Rules, []aclRule{ + { + Topic: fmt.Sprintf("node/update/%s/%s", nodeNetwork, nodeID), + Permission: "allow", + Action: "subscribe", + }, + { + Topic: fmt.Sprintf("ping/%s/%s", serverName, nodeID), + Permission: "allow", + Action: "all", + }, + { + Topic: fmt.Sprintf("update/%s/%s", serverName, nodeID), + Permission: "allow", + Action: "all", + }, + { + Topic: fmt.Sprintf("signal/%s/%s", serverName, nodeID), + Permission: "allow", + Action: "all", + }, + { + Topic: fmt.Sprintf("metrics/%s/%s", serverName, nodeID), + Permission: "allow", + Action: "all", + }, + }...) + payload, err := json.Marshal(aclObject) + if err != nil { + return err + } + req, err := http.NewRequest(http.MethodPut, servercfg.GetEmqxRestEndpoint()+"/api/v5/authorization/sources/built_in_database/username/"+hostID, bytes.NewReader(payload)) + if err != nil { + return err + } + req.Header.Add("content-type", "application/json") + req.Header.Add("authorization", "Bearer "+token) + resp, err := (&http.Client{}).Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + if resp.StatusCode != http.StatusNoContent { + msg, err := io.ReadAll(resp.Body) + if err != nil { + return err + } + return fmt.Errorf("error adding ACL Rules for user %s Error: %v", hostID, string(msg)) + } + return nil +} diff --git a/mq/handlers.go b/mq/handlers.go index 469d0376..528e5d07 100644 --- a/mq/handlers.go +++ b/mq/handlers.go @@ -114,7 +114,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { return } else { if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { - if err = AppendNodeUpdateACL(hu.Host.ID.String(), hu.Node.Network, hu.Node.ID.String(), servercfg.GetServer()); err != nil { + if err = emqx.AppendNodeUpdateACL(hu.Host.ID.String(), hu.Node.Network, hu.Node.ID.String(), servercfg.GetServer()); err != nil { slog.Error("failed to add ACLs for EMQX node", "error", err) return } @@ -143,7 +143,7 @@ func UpdateHost(client mqtt.Client, msg mqtt.Message) { case models.DeleteHost: if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { // delete EMQX credentials for host - if err := DeleteEmqxUser(currentHost.ID.String()); err != nil { + if err := emqx.DeleteEmqxUser(currentHost.ID.String()); err != nil { slog.Error("failed to remove host credentials from EMQX", "id", currentHost.ID, "error", err) } } diff --git a/mq/mq.go b/mq/mq.go index 9be06871..ea61ac11 100644 --- a/mq/mq.go +++ b/mq/mq.go @@ -43,19 +43,19 @@ func SetupMQTT() { if servercfg.GetBrokerType() == servercfg.EmqxBrokerType { time.Sleep(10 * time.Second) // wait for the REST endpoint to be ready // setup authenticator and create admin user - if err := CreateEmqxDefaultAuthenticator(); err != nil { + if err := emqx.CreateEmqxDefaultAuthenticator(); err != nil { logger.Log(0, err.Error()) } - DeleteEmqxUser(servercfg.GetMqUserName()) - if err := CreateEmqxUser(servercfg.GetMqUserName(), servercfg.GetMqPassword(), true); err != nil { + emqx.DeleteEmqxUser(servercfg.GetMqUserName()) + if err := emqx.CreateEmqxUser(servercfg.GetMqUserName(), servercfg.GetMqPassword(), true); err != nil { log.Fatal(err) } // create an ACL authorization source for the built in EMQX MNESIA database - if err := CreateEmqxDefaultAuthorizer(); err != nil { + if err := emqx.CreateEmqxDefaultAuthorizer(); err != nil { logger.Log(0, err.Error()) } // create a default deny ACL to all topics for all users - if err := CreateDefaultDenyRule(); err != nil { + if err := emqx.CreateDefaultDenyRule(); err != nil { log.Fatal(err) } } diff --git a/servercfg/serverconf.go b/servercfg/serverconf.go index 63a9f0cc..33214fba 100644 --- a/servercfg/serverconf.go +++ b/servercfg/serverconf.go @@ -17,10 +17,15 @@ import ( // EmqxBrokerType denotes the broker type for EMQX MQTT const EmqxBrokerType = "emqx" +// Emqxdeploy - emqx deploy type +type Emqxdeploy string + var ( Version = "dev" IsPro = false ErrLicenseValidation error + EmqxCloudDeploy Emqxdeploy = "cloud" + EmqxOnPremDeploy Emqxdeploy = "on-prem" ) // SetHost - sets the host ip @@ -674,3 +679,12 @@ func GetEnvironment() string { } return "" } + +// GetEmqxDeployType - fetches emqx deploy type this server uses +func GetEmqxDeployType() (deployType Emqxdeploy) { + deployType = EmqxOnPremDeploy + if os.Getenv("EMQX_DEPLOY_TYPE") == string(EmqxCloudDeploy) { + deployType = EmqxCloudDeploy + } + return +}