dnscontrol/pkg/dnsgraph/dnsgraph.go
2023-12-12 10:36:39 -05:00

156 lines
3.3 KiB
Go

package dnsgraph
import "github.com/StackExchange/dnscontrol/v4/pkg/dnstree"
// edgeDirection says, from the point of view of the node whose Edges list
// holds an Edge, whether that edge arrives at the node or leaves it.
type edgeDirection uint8
const (
// IncomingEdge marks an edge stored on the node that is depended upon; the
// edge's Node field points at the node that declared the dependency.
IncomingEdge edgeDirection = iota
// OutgoingEdge marks an edge stored on the node that declared the
// dependency; the edge's Node field points at the node it depends on.
OutgoingEdge
)
// Edge is one end of a connection between two nodes. AddEdge records every
// connection twice: as an OutgoingEdge on the source node and as an
// IncomingEdge on the destination node.
type Edge[T Graphable] struct {
// Dependency is the dependency that caused this edge to be created.
Dependency Dependency
// Node is the node at the other end of the edge.
Node *Node[T]
// Direction is relative to the node whose Edges list holds this Edge.
Direction edgeDirection
}
// Edges is a list of edges, as stored on a Node.
type Edges[T Graphable] []Edge[T]
// Node is a single vertex of the graph.
type Node[T Graphable] struct {
// Data is the value this node was created for.
Data T
// Edges holds the incoming and the outgoing edges of this node in one
// list; Edge.Direction tells them apart.
Edges Edges[T]
}
// dnsGraphNodes is a list of node pointers. It is the type of Graph.All and
// of the values stored in Graph.Tree.
type dnsGraphNodes[T Graphable] []*Node[T]
// Graph is a dependency graph over values of type T.
type Graph[T Graphable] struct {
// All lists every node currently in the graph, in the order AddNode
// appended them.
All dnsGraphNodes[T]
// Tree stores, under each Data.GetName(), the nodes added with that name.
// AddEdge queries it with a dependency's NameFQDN to find destinations.
// NOTE(review): the matching rules of DomainTree.Get live in pkg/dnstree
// and are not visible here - confirm before relying on exact-match lookup.
Tree *dnstree.DomainTree[dnsGraphNodes[T]]
}
// CreateGraph builds a Graph out of entries.
//
// Construction happens in two passes. The first pass adds one node per entry,
// which also registers each entry's name in the tree. The second pass walks
// the nodes in insertion order and hands each of their dependencies to
// AddEdge, which looks the dependency up in that tree. Linking only after all
// nodes exist lets an entry depend on one that appears later in entries.
func CreateGraph[T Graphable](entries []T) *Graph[T] {
	g := new(Graph[T])
	g.All = dnsGraphNodes[T]{}
	g.Tree = dnstree.Create[dnsGraphNodes[T]]()

	// Pass 1: one node per entry.
	for i := range entries {
		g.AddNode(entries[i])
	}

	// Pass 2: connect every node to the nodes it depends on.
	for _, n := range g.All {
		deps := n.Data.GetDependencies()
		for _, dep := range deps {
			g.AddEdge(n, dep)
		}
	}

	return g
}
// RemoveNode takes toRemove out of the graph.
//
// Every node that toRemove shares an edge with loses its edges that point
// back at toRemove. The node is then dropped from All and from the list the
// tree holds under the node's name. When the tree has no list for that name
// the tree is left untouched. toRemove.Edges itself is not modified.
func (graph *Graph[T]) RemoveNode(toRemove *Node[T]) {
	// Detach toRemove from each neighbour's edge list.
	for _, e := range toRemove.Edges {
		neighbour := e.Node
		neighbour.Edges = neighbour.Edges.RemoveNode(toRemove)
	}

	graph.All = graph.All.RemoveNode(toRemove)

	name := toRemove.Data.GetName()
	sameName := graph.Tree.Get(name)
	if sameName == nil {
		return
	}
	graph.Tree.Set(name, sameName.RemoveNode(toRemove))
}
// AddNode wraps data in a new node with an empty edge list and registers it
// in the graph: the node is appended to All and to the list the tree keeps
// under data.GetName(). Several nodes may therefore share one name. No edges
// are created here; see AddEdge.
func (graph *Graph[T]) AddNode(data T) {
	name := data.GetName()

	sameName := graph.Tree.Get(name)
	if sameName == nil {
		sameName = dnsGraphNodes[T]{}
	}

	created := &Node[T]{Data: data, Edges: Edges[T]{}}

	graph.All = append(graph.All, created)
	graph.Tree.Set(name, append(sameName, created))
}
// AddEdge connects sourceNode to every node the tree returns for
// dependency.NameFQDN.
//
// For each such destination an OutgoingEdge is appended to sourceNode and a
// matching IncomingEdge is appended to the destination, both carrying
// dependency. Nothing happens when the tree returns no nodes. A destination
// is skipped when it is sourceNode itself (no self-loops) or when sourceNode
// already has an outgoing edge to it (no duplicates).
func (graph *Graph[T]) AddEdge(sourceNode *Node[T], dependency Dependency) {
	targets := graph.Tree.Get(dependency.NameFQDN)
	if targets == nil {
		return
	}

	for _, target := range targets {
		if target == sourceNode || sourceNode.Edges.Contains(target, OutgoingEdge) {
			continue
		}

		// Record the connection on both ends so it can be walked either way.
		outgoing := Edge[T]{
			Dependency: dependency,
			Node:       target,
			Direction:  OutgoingEdge,
		}
		incoming := Edge[T]{
			Dependency: dependency,
			Node:       sourceNode,
			Direction:  IncomingEdge,
		}
		sourceNode.Edges = append(sourceNode.Edges, outgoing)
		target.Edges = append(target.Edges, incoming)
	}
}
// RemoveNode returns a new list holding every element of nodes except the
// occurrences of toRemove, in their original order. The receiver is not
// modified. The result is nil when nothing is left.
func (nodes dnsGraphNodes[T]) RemoveNode(toRemove *Node[T]) dnsGraphNodes[T] {
	var kept dnsGraphNodes[T]
	for i := range nodes {
		if nodes[i] == toRemove {
			continue
		}
		kept = append(kept, nodes[i])
	}
	return kept
}
// RemoveNode returns a new list holding every edge of edges whose Node is not
// toRemove, in the original order and regardless of direction. The receiver
// is not modified. The result is nil when nothing is left.
func (edges Edges[T]) RemoveNode(toRemove *Node[T]) Edges[T] {
	var kept Edges[T]
	for i := range edges {
		if edges[i].Node == toRemove {
			continue
		}
		kept = append(kept, edges[i])
	}
	return kept
}
// Contains reports whether edges has an edge whose Node is toFind and whose
// Direction equals direction. An edge to the same node in the other direction
// does not count.
func (edges Edges[T]) Contains(toFind *Node[T], direction edgeDirection) bool {
	for i := range edges {
		candidate := &edges[i]
		if candidate.Direction != direction || candidate.Node != toFind {
			continue
		}
		return true
	}
	return false
}