mirror of
https://github.com/gravitl/netmaker.git
synced 2025-09-03 19:54:22 +08:00
commit
8b876c17b9
22 changed files with 1123 additions and 168 deletions
|
@ -1,19 +1,25 @@
|
|||
package controller
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/mq"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
func dnsHandlers(r *mux.Router) {
|
||||
|
@ -34,6 +40,274 @@ func dnsHandlers(r *mux.Router) {
|
|||
Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/dns/{network}/{domain}", logic.SecurityCheck(true, http.HandlerFunc(deleteDNS))).
|
||||
Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(createNs))).Methods(http.MethodPost)
|
||||
r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(listNs))).Methods(http.MethodGet)
|
||||
r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(updateNs))).Methods(http.MethodPut)
|
||||
r.HandleFunc("/api/v1/nameserver", logic.SecurityCheck(true, http.HandlerFunc(deleteNs))).Methods(http.MethodDelete)
|
||||
r.HandleFunc("/api/v1/nameserver/global", logic.SecurityCheck(true, http.HandlerFunc(getGlobalNs))).Methods(http.MethodGet)
|
||||
}
|
||||
|
||||
// @Summary List Global Nameservers
|
||||
// @Router /api/v1/nameserver/global [get]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param query network string
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func getGlobalNs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
logic.ReturnSuccessResponseWithJson(w, r, logic.GlobalNsList, "fetched nameservers")
|
||||
}
|
||||
|
||||
// @Summary Create Nameserver
|
||||
// @Router /api/v1/nameserver [post]
|
||||
// @Tags DNS
|
||||
// @Accept json
|
||||
// @Param body body models.NameserverReq
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func createNs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var req schema.Nameserver
|
||||
err := json.NewDecoder(r.Body).Decode(&req)
|
||||
if err != nil {
|
||||
logger.Log(0, "error decoding request body: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
if err := logic.ValidateNameserverReq(req); err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
if req.Tags == nil {
|
||||
req.Tags = make(datatypes.JSONMap)
|
||||
}
|
||||
if gNs, ok := logic.GlobalNsList[req.Name]; ok {
|
||||
req.Servers = gNs.IPs
|
||||
}
|
||||
if !servercfg.IsPro {
|
||||
req.Tags = datatypes.JSONMap{
|
||||
"*": struct{}{},
|
||||
}
|
||||
}
|
||||
if req.MatchAll {
|
||||
req.MatchDomains = []string{"."}
|
||||
}
|
||||
ns := schema.Nameserver{
|
||||
ID: uuid.New().String(),
|
||||
Name: req.Name,
|
||||
NetworkID: req.NetworkID,
|
||||
Description: req.Description,
|
||||
MatchAll: req.MatchAll,
|
||||
MatchDomains: req.MatchDomains,
|
||||
Servers: req.Servers,
|
||||
Tags: req.Tags,
|
||||
Status: true,
|
||||
CreatedBy: r.Header.Get("user"),
|
||||
CreatedAt: time.Now().UTC(),
|
||||
}
|
||||
|
||||
err = ns.Create(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error creating nameserver "+err.Error()), logic.Internal),
|
||||
)
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Create,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: ns.ID,
|
||||
Name: ns.Name,
|
||||
Type: models.NameserverSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(ns.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
})
|
||||
|
||||
go mq.PublishPeerUpdate(false)
|
||||
logic.ReturnSuccessResponseWithJson(w, r, ns, "created nameserver")
|
||||
}
|
||||
|
||||
// @Summary List Nameservers
|
||||
// @Router /api/v1/nameserver [get]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param query network string
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func listNs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
network := r.URL.Query().Get("network")
|
||||
if network == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("network is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
ns := schema.Nameserver{NetworkID: network}
|
||||
list, err := ns.ListByNetwork(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error listing nameservers "+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
logic.ReturnSuccessResponseWithJson(w, r, list, "fetched nameservers")
|
||||
}
|
||||
|
||||
// @Summary Update Nameserver
|
||||
// @Router /api/v1/nameserver [put]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param body body models.NameserverReq
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func updateNs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var updateNs schema.Nameserver
|
||||
err := json.NewDecoder(r.Body).Decode(&updateNs)
|
||||
if err != nil {
|
||||
logger.Log(0, "error decoding request body: ",
|
||||
err.Error())
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
|
||||
if err := logic.ValidateNameserverReq(updateNs); err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
if updateNs.Tags == nil {
|
||||
updateNs.Tags = make(datatypes.JSONMap)
|
||||
}
|
||||
|
||||
ns := schema.Nameserver{ID: updateNs.ID}
|
||||
err = ns.Get(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "badrequest"))
|
||||
return
|
||||
}
|
||||
var updateStatus bool
|
||||
var updateMatchAll bool
|
||||
if updateNs.Status != ns.Status {
|
||||
updateStatus = true
|
||||
}
|
||||
if updateNs.MatchAll != ns.MatchAll {
|
||||
updateMatchAll = true
|
||||
}
|
||||
event := &models.Event{
|
||||
Action: models.Update,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: ns.ID,
|
||||
Name: updateNs.Name,
|
||||
Type: models.NameserverSub,
|
||||
},
|
||||
Diff: models.Diff{
|
||||
Old: ns,
|
||||
New: updateNs,
|
||||
},
|
||||
NetworkID: models.NetworkID(ns.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
}
|
||||
ns.Servers = updateNs.Servers
|
||||
ns.Tags = updateNs.Tags
|
||||
ns.MatchDomains = updateNs.MatchDomains
|
||||
ns.MatchAll = updateNs.MatchAll
|
||||
ns.Description = updateNs.Description
|
||||
ns.Name = updateNs.Name
|
||||
ns.Status = updateNs.Status
|
||||
ns.UpdatedAt = time.Now().UTC()
|
||||
|
||||
err = ns.Update(db.WithContext(context.TODO()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(
|
||||
w,
|
||||
r,
|
||||
logic.FormatError(errors.New("error creating egress resource"+err.Error()), "internal"),
|
||||
)
|
||||
return
|
||||
}
|
||||
if updateStatus {
|
||||
ns.UpdateStatus(db.WithContext(context.TODO()))
|
||||
}
|
||||
if updateMatchAll {
|
||||
ns.UpdateMatchAll(db.WithContext(context.TODO()))
|
||||
}
|
||||
logic.LogEvent(event)
|
||||
go mq.PublishPeerUpdate(false)
|
||||
logic.ReturnSuccessResponseWithJson(w, r, ns, "updated nameserver")
|
||||
}
|
||||
|
||||
// @Summary Delete Nameserver Resource
|
||||
// @Router /api/v1/nameserver [delete]
|
||||
// @Tags Auth
|
||||
// @Accept json
|
||||
// @Param body body models.Egress
|
||||
// @Success 200 {object} models.SuccessResponse
|
||||
// @Failure 400 {object} models.ErrorResponse
|
||||
// @Failure 401 {object} models.ErrorResponse
|
||||
// @Failure 500 {object} models.ErrorResponse
|
||||
func deleteNs(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
id := r.URL.Query().Get("id")
|
||||
if id == "" {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(errors.New("id is required"), "badrequest"))
|
||||
return
|
||||
}
|
||||
ns := schema.Nameserver{ID: id}
|
||||
err := ns.Get(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.BadReq))
|
||||
return
|
||||
}
|
||||
err = ns.Delete(db.WithContext(r.Context()))
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, logic.Internal))
|
||||
return
|
||||
}
|
||||
logic.LogEvent(&models.Event{
|
||||
Action: models.Delete,
|
||||
Source: models.Subject{
|
||||
ID: r.Header.Get("user"),
|
||||
Name: r.Header.Get("user"),
|
||||
Type: models.UserSub,
|
||||
},
|
||||
TriggeredBy: r.Header.Get("user"),
|
||||
Target: models.Subject{
|
||||
ID: ns.ID,
|
||||
Name: ns.Name,
|
||||
Type: models.NameserverSub,
|
||||
},
|
||||
NetworkID: models.NetworkID(ns.NetworkID),
|
||||
Origin: models.Dashboard,
|
||||
})
|
||||
|
||||
go mq.PublishPeerUpdate(false)
|
||||
logic.ReturnSuccessResponseWithJson(w, r, nil, "deleted nameserver resource")
|
||||
}
|
||||
|
||||
// @Summary Gets node DNS entries associated with a network
|
||||
|
|
|
@ -133,6 +133,12 @@ func getExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
gwNode, err := logic.GetNodeByID(client.IngressGatewayID)
|
||||
if err != nil {
|
||||
logic.ReturnErrorResponse(w, r, logic.FormatError(err, "internal"))
|
||||
return
|
||||
}
|
||||
logic.SetDNSOnWgConfig(&gwNode, &client)
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
json.NewEncoder(w).Encode(client)
|
||||
|
@ -288,39 +294,11 @@ func getExtClientConf(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
logic.SetDNSOnWgConfig(&gwnode, &client)
|
||||
defaultDNS := ""
|
||||
if client.DNS != "" {
|
||||
defaultDNS = "DNS = " + client.DNS
|
||||
} else if gwnode.IngressDNS != "" {
|
||||
defaultDNS = "DNS = " + gwnode.IngressDNS
|
||||
}
|
||||
if client.DNS == "" {
|
||||
if len(network.NameServers) > 0 {
|
||||
if defaultDNS == "" {
|
||||
defaultDNS = "DNS = " + strings.Join(network.NameServers, ",")
|
||||
} else {
|
||||
defaultDNS += "," + strings.Join(network.NameServers, ",")
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
// if servercfg.GetManageDNS() {
|
||||
// if gwnode.Address6.IP != nil {
|
||||
// if defaultDNS == "" {
|
||||
// defaultDNS = "DNS = " + gwnode.Address6.IP.String()
|
||||
// } else {
|
||||
// defaultDNS = defaultDNS + ", " + gwnode.Address6.IP.String()
|
||||
// }
|
||||
// }
|
||||
// if gwnode.Address.IP != nil {
|
||||
// if defaultDNS == "" {
|
||||
// defaultDNS = "DNS = " + gwnode.Address.IP.String()
|
||||
// } else {
|
||||
// defaultDNS = defaultDNS + ", " + gwnode.Address.IP.String()
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
defaultMTU := 1420
|
||||
if host.MTU != 0 {
|
||||
|
@ -745,18 +723,10 @@ func createExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
extclient.Tags = make(map[models.TagID]struct{})
|
||||
// extclient.Tags[models.TagID(fmt.Sprintf("%s.%s", extclient.Network,
|
||||
// models.RemoteAccessTagName))] = struct{}{}
|
||||
// set extclient dns to ingressdns if extclient dns is not explicitly set
|
||||
if (extclient.DNS == "") && (node.IngressDNS != "") {
|
||||
network, _ := logic.GetNetwork(node.Network)
|
||||
dns := node.IngressDNS
|
||||
if len(network.NameServers) > 0 {
|
||||
if dns == "" {
|
||||
dns = strings.Join(network.NameServers, ",")
|
||||
} else {
|
||||
dns += "," + strings.Join(network.NameServers, ",")
|
||||
}
|
||||
|
||||
}
|
||||
// set extclient dns to ingressdns if extclient dns is not explicitly
|
||||
gwDNS := logic.GetGwDNS(&node)
|
||||
if (extclient.DNS == "") && (gwDNS != "") {
|
||||
dns := gwDNS
|
||||
extclient.DNS = dns
|
||||
}
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
|
@ -868,7 +838,6 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
|
||||
var update models.CustomExtClient
|
||||
//var oldExtClient models.ExtClient
|
||||
var sendPeerUpdate bool
|
||||
var replacePeers bool
|
||||
err := json.NewDecoder(r.Body).Decode(&update)
|
||||
if err != nil {
|
||||
|
@ -917,19 +886,11 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
var changedID = update.ClientID != oldExtClient.ClientID
|
||||
|
||||
if !reflect.DeepEqual(update.DeniedACLs, oldExtClient.DeniedACLs) {
|
||||
sendPeerUpdate = true
|
||||
logic.SetClientACLs(&oldExtClient, update.DeniedACLs)
|
||||
}
|
||||
if !logic.IsSlicesEqual(update.ExtraAllowedIPs, oldExtClient.ExtraAllowedIPs) {
|
||||
sendPeerUpdate = true
|
||||
}
|
||||
|
||||
if update.Enabled != oldExtClient.Enabled {
|
||||
sendPeerUpdate = true
|
||||
}
|
||||
if update.PublicKey != oldExtClient.PublicKey {
|
||||
//remove old peer entry
|
||||
sendPeerUpdate = true
|
||||
replacePeers = true
|
||||
}
|
||||
if update.RemoteAccessClientID != "" && update.Location == "" {
|
||||
|
@ -974,45 +935,12 @@ func updateExtClient(w http.ResponseWriter, r *http.Request) {
|
|||
if changedID && servercfg.IsDNSMode() {
|
||||
logic.SetDNS()
|
||||
}
|
||||
if replacePeers {
|
||||
if replacePeers || !update.Enabled {
|
||||
if err := mq.PublishDeletedClientPeerUpdate(&oldExtClient); err != nil {
|
||||
slog.Error("error deleting old ext peers", "error", err.Error())
|
||||
}
|
||||
}
|
||||
if sendPeerUpdate { // need to send a peer update to the ingress node as enablement of one of it's clients has changed
|
||||
ingressNode, err := logic.GetNodeByID(newclient.IngressGatewayID)
|
||||
if err == nil {
|
||||
if err = mq.PublishPeerUpdate(false); err != nil {
|
||||
logger.Log(
|
||||
1,
|
||||
"error setting ext peers on",
|
||||
ingressNode.ID.String(),
|
||||
":",
|
||||
err.Error(),
|
||||
)
|
||||
}
|
||||
}
|
||||
if !update.Enabled {
|
||||
ingressHost, err := logic.GetHost(ingressNode.HostID.String())
|
||||
if err != nil {
|
||||
slog.Error(
|
||||
"Failed to get ingress host",
|
||||
"node",
|
||||
ingressNode.ID.String(),
|
||||
"error",
|
||||
err,
|
||||
)
|
||||
return
|
||||
}
|
||||
nodes, err := logic.GetAllNodes()
|
||||
if err != nil {
|
||||
slog.Error("Failed to get nodes", "error", err)
|
||||
return
|
||||
}
|
||||
go mq.PublishSingleHostPeerUpdate(ingressHost, nodes, nil, []models.ExtClient{oldExtClient}, false, nil)
|
||||
}
|
||||
}
|
||||
|
||||
mq.PublishPeerUpdate(false)
|
||||
}()
|
||||
|
||||
}
|
||||
|
|
|
@ -245,6 +245,7 @@ func pull(w http.ResponseWriter, r *http.Request) {
|
|||
DefaultGwIp: hPU.DefaultGwIp,
|
||||
IsInternetGw: hPU.IsInternetGw,
|
||||
EndpointDetection: logic.IsEndpointDetectionEnabled(),
|
||||
DnsNameservers: hPU.DnsNameservers,
|
||||
}
|
||||
|
||||
logger.Log(1, hostID, "completed a pull")
|
||||
|
|
251
logic/dns.go
251
logic/dns.go
|
@ -1,20 +1,64 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
validator "github.com/go-playground/validator/v10"
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
"github.com/gravitl/netmaker/servercfg"
|
||||
"github.com/txn2/txeh"
|
||||
)
|
||||
|
||||
var GetNameserversForNode = getNameserversForNode
|
||||
var GetNameserversForHost = getNameserversForHost
|
||||
var ValidateNameserverReq = validateNameserverReq
|
||||
|
||||
type GlobalNs struct {
|
||||
ID string `json:"id"`
|
||||
IPs []string `json:"ips"`
|
||||
}
|
||||
|
||||
var GlobalNsList = map[string]GlobalNs{
|
||||
"Google": {
|
||||
ID: "Google",
|
||||
IPs: []string{
|
||||
"8.8.8.8",
|
||||
"8.8.4.4",
|
||||
"2001:4860:4860::8888",
|
||||
"2001:4860:4860::8844",
|
||||
},
|
||||
},
|
||||
"Cloudflare": {
|
||||
ID: "Cloudflare",
|
||||
IPs: []string{
|
||||
"1.1.1.1",
|
||||
"1.0.0.1",
|
||||
"2606:4700:4700::1111",
|
||||
"2606:4700:4700::1001",
|
||||
},
|
||||
},
|
||||
"Quad9": {
|
||||
ID: "Quad9",
|
||||
IPs: []string{
|
||||
"9.9.9.9",
|
||||
"149.112.112.112",
|
||||
"2620:fe::fe",
|
||||
"2620:fe::9",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// SetDNS - sets the dns on file
|
||||
func SetDNS() error {
|
||||
hostfile, err := txeh.NewHosts(&txeh.HostsConfig{})
|
||||
|
@ -133,6 +177,34 @@ func GetNodeDNS(network string) ([]models.DNSEntry, error) {
|
|||
return dns, nil
|
||||
}
|
||||
|
||||
func GetGwDNS(node *models.Node) string {
|
||||
if !servercfg.GetManageDNS() {
|
||||
return ""
|
||||
}
|
||||
h, err := GetHost(node.HostID.String())
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
if h.DNS != "yes" {
|
||||
return ""
|
||||
}
|
||||
dns := []string{}
|
||||
if node.Address.IP != nil {
|
||||
dns = append(dns, node.Address.IP.String())
|
||||
}
|
||||
if node.Address6.IP != nil {
|
||||
dns = append(dns, node.Address6.IP.String())
|
||||
}
|
||||
return strings.Join(dns, ",")
|
||||
|
||||
}
|
||||
|
||||
func SetDNSOnWgConfig(gwNode *models.Node, extclient *models.ExtClient) {
|
||||
if extclient.DNS == "" {
|
||||
extclient.DNS = GetGwDNS(gwNode)
|
||||
}
|
||||
}
|
||||
|
||||
// GetCustomDNS - gets the custom DNS of a network
|
||||
func GetCustomDNS(network string) ([]models.DNSEntry, error) {
|
||||
|
||||
|
@ -325,3 +397,182 @@ func CreateDNS(entry models.DNSEntry) (models.DNSEntry, error) {
|
|||
err = database.Insert(k, string(data), database.DNS_TABLE_NAME)
|
||||
return entry, err
|
||||
}
|
||||
|
||||
func validateNameserverReq(ns schema.Nameserver) error {
|
||||
if ns.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
if ns.NetworkID == "" {
|
||||
return errors.New("network is required")
|
||||
}
|
||||
if len(ns.Servers) == 0 {
|
||||
return errors.New("atleast one nameserver should be specified")
|
||||
}
|
||||
if !ns.MatchAll && len(ns.MatchDomains) == 0 {
|
||||
return errors.New("atleast one match domain is required")
|
||||
}
|
||||
if !ns.MatchAll {
|
||||
for _, matchDomain := range ns.MatchDomains {
|
||||
if !IsValidMatchDomain(matchDomain) {
|
||||
return errors.New("invalid match domain")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
|
||||
ns := &schema.Nameserver{
|
||||
NetworkID: node.Network,
|
||||
}
|
||||
nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
|
||||
for _, nsI := range nsLi {
|
||||
if !nsI.Status {
|
||||
continue
|
||||
}
|
||||
_, all := nsI.Tags["*"]
|
||||
if all {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := nsI.Nodes[node.ID.String()]; ok {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if node.IsInternetGateway {
|
||||
globalNs := models.Nameserver{
|
||||
MatchDomain: ".",
|
||||
}
|
||||
for _, nsI := range GlobalNsList {
|
||||
globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
|
||||
}
|
||||
returnNsLi = append(returnNsLi, globalNs)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func getNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
|
||||
if h.DNS != "yes" {
|
||||
return
|
||||
}
|
||||
for _, nodeID := range h.Nodes {
|
||||
node, err := GetNodeByID(nodeID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ns := &schema.Nameserver{
|
||||
NetworkID: node.Network,
|
||||
}
|
||||
nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
|
||||
for _, nsI := range nsLi {
|
||||
if !nsI.Status {
|
||||
continue
|
||||
}
|
||||
_, all := nsI.Tags["*"]
|
||||
if all {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if _, ok := nsI.Nodes[node.ID.String()]; ok {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
if node.IsInternetGateway {
|
||||
globalNs := models.Nameserver{
|
||||
MatchDomain: ".",
|
||||
}
|
||||
for _, nsI := range GlobalNsList {
|
||||
globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
|
||||
}
|
||||
returnNsLi = append(returnNsLi, globalNs)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// IsValidMatchDomain reports whether s is a valid "match domain".
|
||||
// Rules (simple/ASCII):
|
||||
// - "~." is allowed (match all).
|
||||
// - Optional leading "~" allowed (e.g., "~example.com").
|
||||
// - Optional single trailing "." allowed (FQDN form).
|
||||
// - No wildcards "*", no leading ".", no underscores.
|
||||
// - Labels: letters/digits/hyphen (LDH), 1–63 chars, no leading/trailing hyphen.
|
||||
// - Total length (without trailing dot) ≤ 253.
|
||||
func IsValidMatchDomain(s string) bool {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
if s == "~." { // special case: match-all
|
||||
return true
|
||||
}
|
||||
|
||||
// Strip optional leading "~"
|
||||
if strings.HasPrefix(s, "~") {
|
||||
s = s[1:]
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Allow exactly one trailing dot
|
||||
if strings.HasSuffix(s, ".") {
|
||||
s = s[:len(s)-1]
|
||||
if s == "" {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Disallow leading dot, wildcards, underscores
|
||||
if strings.HasPrefix(s, ".") || strings.Contains(s, "*") || strings.Contains(s, "_") {
|
||||
return false
|
||||
}
|
||||
|
||||
// Lowercase for ASCII checks
|
||||
s = strings.ToLower(s)
|
||||
|
||||
// Length check
|
||||
if len(s) > 253 {
|
||||
return false
|
||||
}
|
||||
|
||||
// Label regex: LDH, 1–63, no leading/trailing hyphen
|
||||
reLabel := regexp.MustCompile(`^[a-z0-9](?:[a-z0-9-]{0,61}[a-z0-9])?$`)
|
||||
|
||||
parts := strings.Split(s, ".")
|
||||
for _, lbl := range parts {
|
||||
if len(lbl) == 0 || len(lbl) > 63 {
|
||||
return false
|
||||
}
|
||||
if !reLabel.MatchString(lbl) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
|
@ -142,6 +142,7 @@ func GetPeerUpdateForHost(network string, host *models.Host, allNodes []models.N
|
|||
NodePeers: []wgtypes.PeerConfig{},
|
||||
HostNetworkInfo: models.HostInfoMap{},
|
||||
ServerConfig: GetServerInfo(),
|
||||
DnsNameservers: GetNameserversForHost(host),
|
||||
}
|
||||
if host.DNS == "no" {
|
||||
hostPeerUpdate.ManageDNS = false
|
||||
|
|
|
@ -38,6 +38,24 @@ func UpsertServerSettings(s models.ServerSettings) error {
|
|||
s.BasicAuth = true
|
||||
}
|
||||
|
||||
var userFilters []string
|
||||
for _, userFilter := range s.UserFilters {
|
||||
userFilter = strings.TrimSpace(userFilter)
|
||||
if userFilter != "" {
|
||||
userFilters = append(userFilters, userFilter)
|
||||
}
|
||||
}
|
||||
s.UserFilters = userFilters
|
||||
|
||||
var groupFilters []string
|
||||
for _, groupFilter := range s.GroupFilters {
|
||||
groupFilter = strings.TrimSpace(groupFilter)
|
||||
if groupFilter != "" {
|
||||
groupFilters = append(groupFilters, groupFilter)
|
||||
}
|
||||
}
|
||||
s.GroupFilters = groupFilters
|
||||
|
||||
data, err := json.Marshal(s)
|
||||
if err != nil {
|
||||
return err
|
||||
|
|
|
@ -37,9 +37,81 @@ func Run() {
|
|||
updateAcls()
|
||||
logic.MigrateToGws()
|
||||
migrateToEgressV1()
|
||||
migrateNameservers()
|
||||
resync()
|
||||
}
|
||||
|
||||
func migrateNameservers() {
|
||||
nets, _ := logic.GetNetworks()
|
||||
user, err := logic.GetSuperAdmin()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
for _, netI := range nets {
|
||||
if len(netI.NameServers) > 0 {
|
||||
ns := schema.Nameserver{
|
||||
ID: uuid.NewString(),
|
||||
Name: "upstream nameservers",
|
||||
NetworkID: netI.NetID,
|
||||
Servers: []string{},
|
||||
MatchAll: true,
|
||||
MatchDomains: []string{"."},
|
||||
Tags: datatypes.JSONMap{
|
||||
"*": struct{}{},
|
||||
},
|
||||
Nodes: make(datatypes.JSONMap),
|
||||
Status: true,
|
||||
CreatedBy: user.UserName,
|
||||
}
|
||||
for _, ip := range netI.NameServers {
|
||||
ns.Servers = append(ns.Servers, ip)
|
||||
}
|
||||
ns.Create(db.WithContext(context.TODO()))
|
||||
netI.NameServers = []string{}
|
||||
logic.SaveNetwork(&netI)
|
||||
}
|
||||
}
|
||||
nodes, _ := logic.GetAllNodes()
|
||||
for _, node := range nodes {
|
||||
if !node.IsGw {
|
||||
continue
|
||||
}
|
||||
if node.IngressDNS != "" {
|
||||
if (node.Address.IP != nil && node.Address.IP.String() == node.IngressDNS) ||
|
||||
(node.Address6.IP != nil && node.Address6.IP.String() == node.IngressDNS) {
|
||||
continue
|
||||
}
|
||||
if node.IngressDNS == "8.8.8.8" || node.IngressDNS == "1.1.1.1" || node.IngressDNS == "9.9.9.9" {
|
||||
continue
|
||||
}
|
||||
h, err := logic.GetHost(node.HostID.String())
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ns := schema.Nameserver{
|
||||
ID: uuid.NewString(),
|
||||
Name: fmt.Sprintf("%s gw nameservers", h.Name),
|
||||
NetworkID: node.Network,
|
||||
Servers: []string{node.IngressDNS},
|
||||
MatchAll: true,
|
||||
MatchDomains: []string{"."},
|
||||
Nodes: datatypes.JSONMap{
|
||||
node.ID.String(): struct{}{},
|
||||
},
|
||||
Tags: make(datatypes.JSONMap),
|
||||
Status: true,
|
||||
CreatedBy: user.UserName,
|
||||
}
|
||||
ns.Create(db.WithContext(context.TODO()))
|
||||
node.IngressDNS = ""
|
||||
logic.UpsertNode(&node)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// removes if any stale configurations from previous run.
|
||||
func resync() {
|
||||
|
||||
|
|
|
@ -47,3 +47,13 @@ type DNSEntry struct {
|
|||
Name string `json:"name" validate:"required,name_unique,min=1,max=192,whitespace"`
|
||||
Network string `json:"network" validate:"network_exists"`
|
||||
}
|
||||
|
||||
type NameserverReq struct {
|
||||
Name string `json:"name"`
|
||||
Network string `json:"network"`
|
||||
Description string ` json:"description"`
|
||||
Servers []string `json:"servers"`
|
||||
MatchDomain string `json:"match_domain"`
|
||||
Tags []string `json:"tags"`
|
||||
Status bool `gorm:"status" json:"status"`
|
||||
}
|
||||
|
|
|
@ -55,6 +55,7 @@ const (
|
|||
DashboardSub SubjectType = "DASHBOARD"
|
||||
EnrollmentKeySub SubjectType = "ENROLLMENT_KEY"
|
||||
ClientAppSub SubjectType = "CLIENT-APP"
|
||||
NameserverSub SubjectType = "NAMESERVER"
|
||||
)
|
||||
|
||||
func (sub SubjectType) String() string {
|
||||
|
|
|
@ -28,10 +28,16 @@ type HostPeerUpdate struct {
|
|||
FwUpdate FwUpdate `json:"fw_update"`
|
||||
ReplacePeers bool `json:"replace_peers"`
|
||||
NameServers []string `json:"name_servers"`
|
||||
DnsNameservers []Nameserver `json:"dns_nameservers"`
|
||||
ServerConfig
|
||||
OldPeerUpdateFields
|
||||
}
|
||||
|
||||
type Nameserver struct {
|
||||
IPs []string `json:"ips"`
|
||||
MatchDomain string `json:"match_domain"`
|
||||
}
|
||||
|
||||
type OldPeerUpdateFields struct {
|
||||
NodePeers []wgtypes.PeerConfig `json:"peers" bson:"peers" yaml:"peers"`
|
||||
OldPeers []wgtypes.PeerConfig `json:"Peers"`
|
||||
|
|
|
@ -46,6 +46,7 @@ type UserRemoteGws struct {
|
|||
Status NodeStatus `json:"status"`
|
||||
DnsAddress string `json:"dns_address"`
|
||||
Addresses string `json:"addresses"`
|
||||
MatchDomains []string `json:"match_domains"`
|
||||
}
|
||||
|
||||
// UserRAGs - struct for user access gws
|
||||
|
@ -254,6 +255,7 @@ type HostPull struct {
|
|||
DefaultGwIp net.IP `json:"default_gw_ip"`
|
||||
IsInternetGw bool `json:"is_inet_gw"`
|
||||
EndpointDetection bool `json:"endpoint_detection"`
|
||||
DnsNameservers []Nameserver `json:"dns_nameservers"`
|
||||
}
|
||||
|
||||
type DefaultGwInfo struct {
|
||||
|
|
|
@ -113,6 +113,7 @@ func PublishSingleHostPeerUpdate(host *models.Host, allNodes []models.Node, dele
|
|||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, nodeID := range host.Nodes {
|
||||
|
||||
node, err := logic.GetNodeByID(nodeID)
|
||||
|
|
|
@ -3,6 +3,10 @@ package auth
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/database"
|
||||
"github.com/gravitl/netmaker/logger"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
|
@ -12,9 +16,6 @@ import (
|
|||
"github.com/gravitl/netmaker/pro/idp/google"
|
||||
"github.com/gravitl/netmaker/pro/idp/okta"
|
||||
proLogic "github.com/gravitl/netmaker/pro/logic"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
var (
|
||||
|
@ -85,15 +86,23 @@ func SyncFromIDP() error {
|
|||
}
|
||||
|
||||
if settings.AuthProvider != "" && idpClient != nil {
|
||||
idpUsers, err = idpClient.GetUsers()
|
||||
idpUsers, err = idpClient.GetUsers(settings.UserFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
idpGroups, err = idpClient.GetGroups()
|
||||
idpGroups, err = idpClient.GetGroups(settings.GroupFilters)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(settings.GroupFilters) > 0 {
|
||||
idpUsers = filterUsersByGroupMembership(idpUsers, idpGroups)
|
||||
}
|
||||
|
||||
if len(settings.UserFilters) > 0 {
|
||||
idpGroups = filterGroupsByMembers(idpGroups, idpUsers)
|
||||
}
|
||||
}
|
||||
|
||||
err = syncUsers(idpUsers)
|
||||
|
@ -316,3 +325,64 @@ func syncGroups(idpGroups []idp.Group) error {
|
|||
|
||||
return nil
|
||||
}
|
||||
|
||||
func filterUsersByGroupMembership(idpUsers []idp.User, idpGroups []idp.Group) []idp.User {
|
||||
usersMap := make(map[string]int)
|
||||
for i, user := range idpUsers {
|
||||
usersMap[user.ID] = i
|
||||
}
|
||||
|
||||
filteredUsersMap := make(map[string]int)
|
||||
for _, group := range idpGroups {
|
||||
for _, member := range group.Members {
|
||||
if userIdx, ok := usersMap[member]; ok {
|
||||
// user at index `userIdx` is a member of at least one of the
|
||||
// groups in the `idpGroups` list, so we keep it.
|
||||
filteredUsersMap[member] = userIdx
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i := 0
|
||||
filteredUsers := make([]idp.User, len(filteredUsersMap))
|
||||
for _, userIdx := range filteredUsersMap {
|
||||
filteredUsers[i] = idpUsers[userIdx]
|
||||
i++
|
||||
}
|
||||
|
||||
return filteredUsers
|
||||
}
|
||||
|
||||
func filterGroupsByMembers(idpGroups []idp.Group, idpUsers []idp.User) []idp.Group {
|
||||
usersMap := make(map[string]int)
|
||||
for i, user := range idpUsers {
|
||||
usersMap[user.ID] = i
|
||||
}
|
||||
|
||||
filteredGroupsMap := make(map[int]bool)
|
||||
for i, group := range idpGroups {
|
||||
var members []string
|
||||
for _, member := range group.Members {
|
||||
if _, ok := usersMap[member]; ok {
|
||||
members = append(members, member)
|
||||
}
|
||||
|
||||
if len(members) > 0 {
|
||||
// the group at index `i` has members from the `idpUsers` list,
|
||||
// so we keep it.
|
||||
filteredGroupsMap[i] = true
|
||||
// filter out members that were not provided in the `idpUsers` list.
|
||||
idpGroups[i].Members = members
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
i := 0
|
||||
filteredGroups := make([]idp.Group, len(filteredGroupsMap))
|
||||
for groupIdx := range filteredGroupsMap {
|
||||
filteredGroups[i] = idpGroups[groupIdx]
|
||||
i++
|
||||
}
|
||||
|
||||
return filteredGroups
|
||||
}
|
||||
|
|
|
@ -1171,11 +1171,7 @@ func getRemoteAccessGatewayConf(w http.ResponseWriter, r *http.Request) {
|
|||
userConf.OwnerID = user.UserName
|
||||
userConf.RemoteAccessClientID = req.RemoteAccessClientID
|
||||
userConf.IngressGatewayID = node.ID.String()
|
||||
|
||||
// set extclient dns to ingressdns if extclient dns is not explicitly set
|
||||
if (userConf.DNS == "") && (node.IngressDNS != "") {
|
||||
userConf.DNS = node.IngressDNS
|
||||
}
|
||||
logic.SetDNSOnWgConfig(&node, &userConf)
|
||||
|
||||
userConf.Network = node.Network
|
||||
host, err := logic.GetHost(node.HostID.String())
|
||||
|
@ -1301,9 +1297,8 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
|
||||
gws := userGws[node.Network]
|
||||
if extClient.DNS == "" {
|
||||
extClient.DNS = node.IngressDNS
|
||||
}
|
||||
|
||||
logic.SetDNSOnWgConfig(&node, &extClient)
|
||||
|
||||
extClient.IngressGatewayEndpoint = utils.GetExtClientEndpoint(
|
||||
host.EndpointIP,
|
||||
|
@ -1311,7 +1306,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
|
|||
logic.GetPeerListenPort(host),
|
||||
)
|
||||
extClient.AllowedIPs = logic.GetExtclientAllowedIPs(extClient)
|
||||
gws = append(gws, models.UserRemoteGws{
|
||||
gw := models.UserRemoteGws{
|
||||
GwID: node.ID.String(),
|
||||
GWName: host.Name,
|
||||
Network: node.Network,
|
||||
|
@ -1326,7 +1321,14 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
|
|||
Status: node.Status,
|
||||
DnsAddress: node.IngressDNS,
|
||||
Addresses: utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()),
|
||||
})
|
||||
}
|
||||
if !node.IsInternetGateway {
|
||||
hNs := logic.GetNameserversForNode(&node)
|
||||
for _, nsI := range hNs {
|
||||
gw.MatchDomains = append(gw.MatchDomains, nsI.MatchDomain)
|
||||
}
|
||||
}
|
||||
gws = append(gws, gw)
|
||||
userGws[node.Network] = gws
|
||||
delete(userGwNodes, node.ID.String())
|
||||
}
|
||||
|
@ -1357,7 +1359,7 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
|
|||
}
|
||||
gws := userGws[node.Network]
|
||||
|
||||
gws = append(gws, models.UserRemoteGws{
|
||||
gw := models.UserRemoteGws{
|
||||
GwID: node.ID.String(),
|
||||
GWName: host.Name,
|
||||
Network: node.Network,
|
||||
|
@ -1370,7 +1372,14 @@ func getUserRemoteAccessGwsV1(w http.ResponseWriter, r *http.Request) {
|
|||
Status: node.Status,
|
||||
DnsAddress: node.IngressDNS,
|
||||
Addresses: utils.NoEmptyStringToCsv(node.Address.String(), node.Address6.String()),
|
||||
})
|
||||
}
|
||||
if !node.IsInternetGateway {
|
||||
hNs := logic.GetNameserversForNode(&node)
|
||||
for _, nsI := range hNs {
|
||||
gw.MatchDomains = append(gw.MatchDomains, nsI.MatchDomain)
|
||||
}
|
||||
}
|
||||
gws = append(gws, gw)
|
||||
userGws[node.Network] = gws
|
||||
}
|
||||
|
||||
|
|
|
@ -4,10 +4,11 @@ import (
|
|||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/pro/idp"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/pro/idp"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
|
@ -26,89 +27,103 @@ func NewAzureEntraIDClient() *Client {
|
|||
}
|
||||
}
|
||||
|
||||
func (a *Client) GetUsers() ([]idp.User, error) {
|
||||
func (a *Client) GetUsers(filters []string) ([]idp.User, error) {
|
||||
accessToken, err := a.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", "https://graph.microsoft.com/v1.0/users?$select=id,userPrincipalName,displayName,accountEnabled", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
getUsersURL := "https://graph.microsoft.com/v1.0/users?$select=id,userPrincipalName,displayName,accountEnabled"
|
||||
if len(filters) > 0 {
|
||||
getUsersURL += "&" + buildPrefixFilter("userPrincipalName", filters)
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
var users getUsersResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&users)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retval := make([]idp.User, len(users.Value))
|
||||
for i, user := range users.Value {
|
||||
retval[i] = idp.User{
|
||||
ID: user.Id,
|
||||
Username: user.UserPrincipalName,
|
||||
DisplayName: user.DisplayName,
|
||||
AccountDisabled: !user.AccountEnabled,
|
||||
var retval []idp.User
|
||||
for getUsersURL != "" {
|
||||
req, err := http.NewRequest("GET", getUsersURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var users getUsersResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&users)
|
||||
_ = resp.Body.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, user := range users.Value {
|
||||
retval = append(retval, idp.User{
|
||||
ID: user.Id,
|
||||
Username: user.UserPrincipalName,
|
||||
DisplayName: user.DisplayName,
|
||||
AccountDisabled: !user.AccountEnabled,
|
||||
})
|
||||
}
|
||||
|
||||
getUsersURL = users.NextLink
|
||||
}
|
||||
|
||||
return retval, nil
|
||||
}
|
||||
|
||||
func (a *Client) GetGroups() ([]idp.Group, error) {
|
||||
func (a *Client) GetGroups(filters []string) ([]idp.Group, error) {
|
||||
accessToken, err := a.getAccessToken()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
client := &http.Client{}
|
||||
req, err := http.NewRequest("GET", "https://graph.microsoft.com/v1.0/groups?$select=id,displayName&$expand=members($select=id)", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
getGroupsURL := "https://graph.microsoft.com/v1.0/groups?$select=id,displayName&$expand=members($select=id)"
|
||||
if len(filters) > 0 {
|
||||
getGroupsURL += "&" + buildPrefixFilter("displayName", filters)
|
||||
}
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
var retval []idp.Group
|
||||
for getGroupsURL != "" {
|
||||
req, err := http.NewRequest("GET", getGroupsURL, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
req.Header.Add("Authorization", "Bearer "+accessToken)
|
||||
req.Header.Add("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var groups getGroupsResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&groups)
|
||||
_ = resp.Body.Close()
|
||||
}()
|
||||
|
||||
var groups getGroupsResponse
|
||||
err = json.NewDecoder(resp.Body).Decode(&groups)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retval := make([]idp.Group, len(groups.Value))
|
||||
for i, group := range groups.Value {
|
||||
retvalMembers := make([]string, len(group.Members))
|
||||
for j, member := range group.Members {
|
||||
retvalMembers[j] = member.Id
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retval[i] = idp.Group{
|
||||
ID: group.Id,
|
||||
Name: group.DisplayName,
|
||||
Members: retvalMembers,
|
||||
for _, group := range groups.Value {
|
||||
retvalMembers := make([]string, len(group.Members))
|
||||
for j, member := range group.Members {
|
||||
retvalMembers[j] = member.Id
|
||||
}
|
||||
|
||||
retval = append(retval, idp.Group{
|
||||
ID: group.Id,
|
||||
Name: group.DisplayName,
|
||||
Members: retvalMembers,
|
||||
})
|
||||
}
|
||||
|
||||
getGroupsURL = groups.NextLink
|
||||
}
|
||||
|
||||
return retval, nil
|
||||
|
@ -144,6 +159,18 @@ func (a *Client) getAccessToken() (string, error) {
|
|||
return "", errors.New("failed to get access token")
|
||||
}
|
||||
|
||||
func buildPrefixFilter(field string, prefixes []string) string {
|
||||
if len(prefixes) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(prefixes) == 1 {
|
||||
return fmt.Sprintf("$filter=startswith(%s,'%s')", field, prefixes[0])
|
||||
}
|
||||
|
||||
return buildPrefixFilter(field, prefixes[:1]) + "%20or%20" + buildPrefixFilter(field, prefixes[1:])
|
||||
}
|
||||
|
||||
type getUsersResponse struct {
|
||||
OdataContext string `json:"@odata.context"`
|
||||
Value []struct {
|
||||
|
@ -152,6 +179,7 @@ type getUsersResponse struct {
|
|||
DisplayName string `json:"displayName"`
|
||||
AccountEnabled bool `json:"accountEnabled"`
|
||||
} `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
|
||||
type getGroupsResponse struct {
|
||||
|
@ -164,4 +192,5 @@ type getGroupsResponse struct {
|
|||
Id string `json:"id"`
|
||||
} `json:"members"`
|
||||
} `json:"value"`
|
||||
NextLink string `json:"@odata.nextLink"`
|
||||
}
|
||||
|
|
|
@ -4,6 +4,8 @@ import (
|
|||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/pro/idp"
|
||||
admindir "google.golang.org/api/admin/directory/v1"
|
||||
|
@ -59,13 +61,28 @@ func NewGoogleWorkspaceClient() (*Client, error) {
|
|||
}, nil
|
||||
}
|
||||
|
||||
func (g *Client) GetUsers() ([]idp.User, error) {
|
||||
func (g *Client) GetUsers(filters []string) ([]idp.User, error) {
|
||||
var retval []idp.User
|
||||
err := g.service.Users.List().
|
||||
Customer("my_customer").
|
||||
Fields("users(id,primaryEmail,name,suspended,archived)", "nextPageToken").
|
||||
Pages(context.TODO(), func(users *admindir.Users) error {
|
||||
for _, user := range users.Users {
|
||||
var keep bool
|
||||
if len(filters) > 0 {
|
||||
for _, filter := range filters {
|
||||
if strings.HasPrefix(user.PrimaryEmail, filter) {
|
||||
keep = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
keep = true
|
||||
}
|
||||
|
||||
if !keep {
|
||||
continue
|
||||
}
|
||||
|
||||
retval = append(retval, idp.User{
|
||||
ID: user.Id,
|
||||
Username: user.PrimaryEmail,
|
||||
|
@ -81,13 +98,28 @@ func (g *Client) GetUsers() ([]idp.User, error) {
|
|||
return retval, err
|
||||
}
|
||||
|
||||
func (g *Client) GetGroups() ([]idp.Group, error) {
|
||||
func (g *Client) GetGroups(filters []string) ([]idp.Group, error) {
|
||||
var retval []idp.Group
|
||||
err := g.service.Groups.List().
|
||||
Customer("my_customer").
|
||||
Fields("groups(id,name)", "nextPageToken").
|
||||
Pages(context.TODO(), func(groups *admindir.Groups) error {
|
||||
for _, group := range groups.Groups {
|
||||
var keep bool
|
||||
if len(filters) > 0 {
|
||||
for _, filter := range filters {
|
||||
if strings.HasPrefix(group.Name, filter) {
|
||||
keep = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
keep = true
|
||||
}
|
||||
|
||||
if !keep {
|
||||
continue
|
||||
}
|
||||
|
||||
var retvalMembers []string
|
||||
err := g.service.Members.List(group.Id).
|
||||
Fields("members(id)", "nextPageToken").
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
package idp
|
||||
|
||||
type Client interface {
|
||||
GetUsers() ([]User, error)
|
||||
GetGroups() ([]Group, error)
|
||||
GetUsers(filters []string) ([]User, error)
|
||||
GetGroups(filters []string) ([]Group, error)
|
||||
}
|
||||
|
||||
type User struct {
|
||||
|
|
|
@ -3,6 +3,7 @@ package okta
|
|||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/pro/idp"
|
||||
"github.com/okta/okta-sdk-golang/v5/okta"
|
||||
|
@ -42,12 +43,14 @@ func (o *Client) Verify() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (o *Client) GetUsers() ([]idp.User, error) {
|
||||
func (o *Client) GetUsers(filters []string) ([]idp.User, error) {
|
||||
var retval []idp.User
|
||||
var allUsersFetched bool
|
||||
|
||||
for !allUsersFetched {
|
||||
users, resp, err := o.client.UserAPI.ListUsers(context.TODO()).Execute()
|
||||
users, resp, err := o.client.UserAPI.ListUsers(context.TODO()).
|
||||
Search(buildPrefixFilter("profile.login", filters)).
|
||||
Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -81,12 +84,14 @@ func (o *Client) GetUsers() ([]idp.User, error) {
|
|||
return retval, nil
|
||||
}
|
||||
|
||||
func (o *Client) GetGroups() ([]idp.Group, error) {
|
||||
func (o *Client) GetGroups(filters []string) ([]idp.Group, error) {
|
||||
var retval []idp.Group
|
||||
var allGroupsFetched bool
|
||||
|
||||
for !allGroupsFetched {
|
||||
groups, resp, err := o.client.GroupAPI.ListGroups(context.TODO()).Execute()
|
||||
groups, resp, err := o.client.GroupAPI.ListGroups(context.TODO()).
|
||||
Search(buildPrefixFilter("profile.name", filters)).
|
||||
Execute()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -122,3 +127,15 @@ func (o *Client) GetGroups() ([]idp.Group, error) {
|
|||
|
||||
return retval, nil
|
||||
}
|
||||
|
||||
func buildPrefixFilter(field string, prefixes []string) string {
|
||||
if len(prefixes) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if len(prefixes) == 1 {
|
||||
return fmt.Sprintf("%s sw \"%s\"", field, prefixes[0])
|
||||
}
|
||||
|
||||
return buildPrefixFilter(field, prefixes[:1]) + " or " + buildPrefixFilter(field, prefixes[1:])
|
||||
}
|
||||
|
|
|
@ -155,6 +155,9 @@ func InitPro() {
|
|||
logic.GetFwRulesForNodeAndPeerOnGw = proLogic.GetFwRulesForNodeAndPeerOnGw
|
||||
logic.GetFwRulesForUserNodesOnGw = proLogic.GetFwRulesForUserNodesOnGw
|
||||
logic.GetHostLocInfo = proLogic.GetHostLocInfo
|
||||
logic.GetNameserversForHost = proLogic.GetNameserversForHost
|
||||
logic.GetNameserversForNode = proLogic.GetNameserversForNode
|
||||
logic.ValidateNameserverReq = proLogic.ValidateNameserverReq
|
||||
|
||||
}
|
||||
|
||||
|
|
171
pro/logic/dns.go
Normal file
171
pro/logic/dns.go
Normal file
|
@ -0,0 +1,171 @@
|
|||
package logic
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"github.com/gravitl/netmaker/logic"
|
||||
"github.com/gravitl/netmaker/models"
|
||||
"github.com/gravitl/netmaker/schema"
|
||||
)
|
||||
|
||||
func ValidateNameserverReq(ns schema.Nameserver) error {
|
||||
if ns.Name == "" {
|
||||
return errors.New("name is required")
|
||||
}
|
||||
if ns.NetworkID == "" {
|
||||
return errors.New("network is required")
|
||||
}
|
||||
if len(ns.Servers) == 0 {
|
||||
return errors.New("atleast one nameserver should be specified")
|
||||
}
|
||||
if !ns.MatchAll && len(ns.MatchDomains) == 0 {
|
||||
return errors.New("atleast one match domain is required")
|
||||
}
|
||||
if !ns.MatchAll {
|
||||
for _, matchDomain := range ns.MatchDomains {
|
||||
if !logic.IsValidMatchDomain(matchDomain) {
|
||||
return errors.New("invalid match domain")
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(ns.Tags) > 0 {
|
||||
for tagI := range ns.Tags {
|
||||
if tagI == "*" {
|
||||
continue
|
||||
}
|
||||
_, err := GetTag(models.TagID(tagI))
|
||||
if err != nil {
|
||||
return errors.New("invalid tag")
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetNameserversForNode(node *models.Node) (returnNsLi []models.Nameserver) {
|
||||
ns := &schema.Nameserver{
|
||||
NetworkID: node.Network,
|
||||
}
|
||||
nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
|
||||
for _, nsI := range nsLi {
|
||||
if !nsI.Status {
|
||||
continue
|
||||
}
|
||||
_, all := nsI.Tags["*"]
|
||||
if all {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
foundTag := false
|
||||
for tagI := range node.Tags {
|
||||
if _, ok := nsI.Tags[tagI.String()]; ok {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
foundTag = true
|
||||
}
|
||||
if foundTag {
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundTag {
|
||||
continue
|
||||
}
|
||||
if _, ok := nsI.Nodes[node.ID.String()]; ok {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
if node.IsInternetGateway {
|
||||
globalNs := models.Nameserver{
|
||||
MatchDomain: ".",
|
||||
}
|
||||
for _, nsI := range logic.GlobalNsList {
|
||||
globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
|
||||
}
|
||||
returnNsLi = append(returnNsLi, globalNs)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func GetNameserversForHost(h *models.Host) (returnNsLi []models.Nameserver) {
|
||||
if h.DNS != "yes" {
|
||||
return
|
||||
}
|
||||
for _, nodeID := range h.Nodes {
|
||||
node, err := logic.GetNodeByID(nodeID)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
ns := &schema.Nameserver{
|
||||
NetworkID: node.Network,
|
||||
}
|
||||
nsLi, _ := ns.ListByNetwork(db.WithContext(context.TODO()))
|
||||
for _, nsI := range nsLi {
|
||||
if !nsI.Status {
|
||||
continue
|
||||
}
|
||||
_, all := nsI.Tags["*"]
|
||||
if all {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
continue
|
||||
}
|
||||
foundTag := false
|
||||
for tagI := range node.Tags {
|
||||
if _, ok := nsI.Tags[tagI.String()]; ok {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
foundTag = true
|
||||
}
|
||||
if foundTag {
|
||||
break
|
||||
}
|
||||
}
|
||||
if foundTag {
|
||||
continue
|
||||
}
|
||||
if _, ok := nsI.Nodes[node.ID.String()]; ok {
|
||||
for _, matchDomain := range nsI.MatchDomains {
|
||||
returnNsLi = append(returnNsLi, models.Nameserver{
|
||||
IPs: nsI.Servers,
|
||||
MatchDomain: matchDomain,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
if node.IsInternetGateway {
|
||||
globalNs := models.Nameserver{
|
||||
MatchDomain: ".",
|
||||
}
|
||||
for _, nsI := range logic.GlobalNsList {
|
||||
globalNs.IPs = append(globalNs.IPs, nsI.IPs...)
|
||||
}
|
||||
returnNsLi = append(returnNsLi, globalNs)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
58
schema/dns.go
Normal file
58
schema/dns.go
Normal file
|
@ -0,0 +1,58 @@
|
|||
package schema
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/gravitl/netmaker/db"
|
||||
"gorm.io/datatypes"
|
||||
)
|
||||
|
||||
type Nameserver struct {
|
||||
ID string `gorm:"primaryKey" json:"id"`
|
||||
Name string `gorm:"name" json:"name"`
|
||||
NetworkID string `gorm:"network_id" json:"network_id"`
|
||||
Description string `gorm:"description" json:"description"`
|
||||
Servers datatypes.JSONSlice[string] `gorm:"servers" json:"servers"`
|
||||
MatchAll bool `gorm:"match_all" json:"match_all"`
|
||||
MatchDomains datatypes.JSONSlice[string] `gorm:"match_domains" json:"match_domains"`
|
||||
Tags datatypes.JSONMap `gorm:"tags" json:"tags"`
|
||||
Nodes datatypes.JSONMap `gorm:"nodes" json:"nodes"`
|
||||
Status bool `gorm:"status" json:"status"`
|
||||
CreatedBy string `gorm:"created_by" json:"created_by"`
|
||||
CreatedAt time.Time `gorm:"created_at" json:"created_at"`
|
||||
UpdatedAt time.Time `gorm:"updated_at" json:"updated_at"`
|
||||
}
|
||||
|
||||
func (ns *Nameserver) Get(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&Nameserver{}).First(&ns).Where("id = ?", ns.ID).Error
|
||||
}
|
||||
|
||||
func (ns *Nameserver) Update(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&Nameserver{}).Where("id = ?", ns.ID).Updates(&ns).Error
|
||||
}
|
||||
|
||||
func (ns *Nameserver) Create(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&Nameserver{}).Create(&ns).Error
|
||||
}
|
||||
|
||||
func (ns *Nameserver) ListByNetwork(ctx context.Context) (dnsli []Nameserver, err error) {
|
||||
err = db.FromContext(ctx).Model(&Nameserver{}).Where("network_id = ?", ns.NetworkID).Find(&dnsli).Error
|
||||
return
|
||||
}
|
||||
|
||||
func (ns *Nameserver) Delete(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&Nameserver{}).Where("id = ?", ns.ID).Delete(&ns).Error
|
||||
}
|
||||
|
||||
func (ns *Nameserver) UpdateStatus(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&Nameserver{}).Where("id = ?", ns.ID).Updates(map[string]any{
|
||||
"status": ns.Status,
|
||||
}).Error
|
||||
}
|
||||
|
||||
func (ns *Nameserver) UpdateMatchAll(ctx context.Context) error {
|
||||
return db.FromContext(ctx).Model(&Nameserver{}).Where("id = ?", ns.ID).Updates(map[string]any{
|
||||
"match_all": ns.MatchAll,
|
||||
}).Error
|
||||
}
|
|
@ -7,5 +7,6 @@ func ListModels() []interface{} {
|
|||
&Egress{},
|
||||
&UserAccessToken{},
|
||||
&Event{},
|
||||
&Nameserver{},
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Reference in a new issue