batch updates in notifier (#1905)

This commit is contained in:
Kristoffer Dalby 2024-04-27 10:47:39 +02:00 committed by GitHub
parent fef8261339
commit cb0b495ea9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 541 additions and 158 deletions

View File

@ -137,7 +137,7 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) {
noisePrivateKey: noisePrivateKey,
registrationCache: registrationCache,
pollNetMapStreamWG: sync.WaitGroup{},
nodeNotifier: notifier.NewNotifier(),
nodeNotifier: notifier.NewNotifier(cfg),
mapSessions: make(map[types.NodeID]*mapSession),
}

View File

@ -18,7 +18,12 @@ var (
Namespace: prometheusNamespace,
Name: "notifier_update_sent_total",
Help: "total count of update sent on nodes channel",
}, []string{"status", "type"})
}, []string{"status", "type", "trigger"})
notifierUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{
Namespace: prometheusNamespace,
Name: "notifier_update_received_total",
Help: "total count of updates received by notifier",
}, []string{"type", "trigger"})
notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{
Namespace: prometheusNamespace,
Name: "notifier_open_channels_total",

View File

@ -3,7 +3,7 @@ package notifier
import (
"context"
"fmt"
"slices"
"sort"
"strings"
"sync"
"time"
@ -11,19 +11,27 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/puzpuzpuz/xsync/v3"
"github.com/rs/zerolog/log"
"tailscale.com/tailcfg"
"tailscale.com/util/set"
)
type Notifier struct {
l sync.RWMutex
nodes map[types.NodeID]chan<- types.StateUpdate
connected *xsync.MapOf[types.NodeID, bool]
b *batcher
}
func NewNotifier() *Notifier {
return &Notifier{
func NewNotifier(cfg *types.Config) *Notifier {
n := &Notifier{
nodes: make(map[types.NodeID]chan<- types.StateUpdate),
connected: xsync.NewMapOf[types.NodeID, bool](),
}
b := newBatcher(cfg.Tuning.BatchChangeDelay, n)
n.b = b
// TODO(kradalby): clean this up
go b.doWork()
return n
}
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
@ -108,13 +116,8 @@ func (n *Notifier) NotifyWithIgnore(
update types.StateUpdate,
ignoreNodeIDs ...types.NodeID,
) {
for nodeID := range n.nodes {
if slices.Contains(ignoreNodeIDs, nodeID) {
continue
}
n.NotifyByNodeID(ctx, update, nodeID)
}
notifierUpdateReceived.WithLabelValues(update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
n.b.addOrPassthrough(update)
}
func (n *Notifier) NotifyByNodeID(
@ -139,10 +142,10 @@ func (n *Notifier) NotifyByNodeID(
log.Error().
Err(ctx.Err()).
Uint64("node.id", nodeID.Uint64()).
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Any("origin", types.NotifyOriginKey.Value(ctx)).
Any("origin-hostname", types.NotifyHostnameKey.Value(ctx)).
Msgf("update not sent, context cancelled")
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc()
notifierUpdateSent.WithLabelValues("cancelled", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
return
case c <- update:
@ -151,11 +154,23 @@ func (n *Notifier) NotifyByNodeID(
Any("origin", ctx.Value("origin")).
Any("origin-hostname", ctx.Value("hostname")).
Msgf("update successfully sent on chan")
notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc()
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), types.NotifyOriginKey.Value(ctx)).Inc()
}
}
}
func (n *Notifier) sendAll(update types.StateUpdate) {
start := time.Now()
n.l.RLock()
defer n.l.RUnlock()
notifierWaitForLock.WithLabelValues("send-all").Observe(time.Since(start).Seconds())
for _, c := range n.nodes {
c <- update
notifierUpdateSent.WithLabelValues("ok", update.Type.String(), "send-all").Inc()
}
}
func (n *Notifier) String() string {
n.l.RLock()
defer n.l.RUnlock()
@ -177,3 +192,166 @@ func (n *Notifier) String() string {
return b.String()
}
type batcher struct {
tick *time.Ticker
mu sync.Mutex
cancelCh chan struct{}
changedNodeIDs set.Slice[types.NodeID]
nodesChanged bool
patches map[types.NodeID]tailcfg.PeerChange
patchesChanged bool
n *Notifier
}
func newBatcher(batchTime time.Duration, n *Notifier) *batcher {
return &batcher{
tick: time.NewTicker(batchTime),
cancelCh: make(chan struct{}),
patches: make(map[types.NodeID]tailcfg.PeerChange),
n: n,
}
}
func (b *batcher) close() {
b.cancelCh <- struct{}{}
}
// addOrPassthrough adds the update to the batcher, if it is not a
// type that is currently batched, it will be sent immediately.
func (b *batcher) addOrPassthrough(update types.StateUpdate) {
b.mu.Lock()
defer b.mu.Unlock()
switch update.Type {
case types.StatePeerChanged:
b.changedNodeIDs.Add(update.ChangeNodes...)
b.nodesChanged = true
case types.StatePeerChangedPatch:
for _, newPatch := range update.ChangePatches {
if curr, ok := b.patches[types.NodeID(newPatch.NodeID)]; ok {
overwritePatch(&curr, newPatch)
b.patches[types.NodeID(newPatch.NodeID)] = curr
} else {
b.patches[types.NodeID(newPatch.NodeID)] = *newPatch
}
}
b.patchesChanged = true
default:
b.n.sendAll(update)
}
}
// flush sends all the accumulated patches to all
// nodes in the notifier.
func (b *batcher) flush() {
b.mu.Lock()
defer b.mu.Unlock()
if b.nodesChanged || b.patchesChanged {
var patches []*tailcfg.PeerChange
// If a node is getting a full update from a change
// node update, then the patch can be dropped.
for nodeID, patch := range b.patches {
if b.changedNodeIDs.Contains(nodeID) {
delete(b.patches, nodeID)
} else {
patches = append(patches, &patch)
}
}
changedNodes := b.changedNodeIDs.Slice().AsSlice()
sort.Slice(changedNodes, func(i, j int) bool {
return changedNodes[i] < changedNodes[j]
})
if b.changedNodeIDs.Slice().Len() > 0 {
update := types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changedNodes,
}
b.n.sendAll(update)
}
if len(patches) > 0 {
patchUpdate := types.StateUpdate{
Type: types.StatePeerChangedPatch,
ChangePatches: patches,
}
b.n.sendAll(patchUpdate)
}
b.changedNodeIDs = set.Slice[types.NodeID]{}
b.nodesChanged = false
b.patches = make(map[types.NodeID]tailcfg.PeerChange, len(b.patches))
b.patchesChanged = false
}
}
func (b *batcher) doWork() {
for {
select {
case <-b.cancelCh:
return
case <-b.tick.C:
b.flush()
}
}
}
// overwritePatch takes the current patch and a newer patch
// and override any field that has changed
func overwritePatch(currPatch, newPatch *tailcfg.PeerChange) {
if newPatch.DERPRegion != 0 {
currPatch.DERPRegion = newPatch.DERPRegion
}
if newPatch.Cap != 0 {
currPatch.Cap = newPatch.Cap
}
if newPatch.CapMap != nil {
currPatch.CapMap = newPatch.CapMap
}
if newPatch.Endpoints != nil {
currPatch.Endpoints = newPatch.Endpoints
}
if newPatch.Key != nil {
currPatch.Key = newPatch.Key
}
if newPatch.KeySignature != nil {
currPatch.KeySignature = newPatch.KeySignature
}
if newPatch.DiscoKey != nil {
currPatch.DiscoKey = newPatch.DiscoKey
}
if newPatch.Online != nil {
currPatch.Online = newPatch.Online
}
if newPatch.LastSeen != nil {
currPatch.LastSeen = newPatch.LastSeen
}
if newPatch.KeyExpiry != nil {
currPatch.KeyExpiry = newPatch.KeyExpiry
}
if newPatch.Capabilities != nil {
currPatch.Capabilities = newPatch.Capabilities
}
}

View File

@ -0,0 +1,249 @@
package notifier
import (
"context"
"net/netip"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"tailscale.com/tailcfg"
)
func TestBatcher(t *testing.T) {
tests := []struct {
name string
updates []types.StateUpdate
want []types.StateUpdate
}{
{
name: "full-passthrough",
updates: []types.StateUpdate{
{
Type: types.StateFullUpdate,
},
},
want: []types.StateUpdate{
{
Type: types.StateFullUpdate,
},
},
},
{
name: "derp-passthrough",
updates: []types.StateUpdate{
{
Type: types.StateDERPUpdated,
},
},
want: []types.StateUpdate{
{
Type: types.StateDERPUpdated,
},
},
},
{
name: "single-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2,
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2,
},
},
},
},
{
name: "merge-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 4,
},
},
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 3,
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChanged,
ChangeNodes: []types.NodeID{
2, 3, 4,
},
},
},
},
{
name: "single-patch-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
},
},
{
name: "merge-patch-to-same-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 5,
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 6,
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 2,
DERPRegion: 6,
},
},
},
},
},
{
name: "merge-patch-to-multiple-node-update",
updates: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
},
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
netip.MustParseAddrPort("2.2.2.2:8080"),
},
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 4,
DERPRegion: 6,
},
},
},
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 4,
Cap: tailcfg.CapabilityVersion(54),
},
},
},
},
want: []types.StateUpdate{
{
Type: types.StatePeerChangedPatch,
ChangePatches: []*tailcfg.PeerChange{
{
NodeID: 3,
Endpoints: []netip.AddrPort{
netip.MustParseAddrPort("1.1.1.1:9090"),
netip.MustParseAddrPort("2.2.2.2:8080"),
},
},
{
NodeID: 4,
DERPRegion: 6,
Cap: tailcfg.CapabilityVersion(54),
},
},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
n := NewNotifier(&types.Config{
Tuning: types.Tuning{
// We will call flush manually for the tests,
// so do not run the worker.
BatchChangeDelay: time.Hour,
},
})
ch := make(chan types.StateUpdate, 30)
defer close(ch)
n.AddNode(1, ch)
defer n.RemoveNode(1)
for _, u := range tt.updates {
n.NotifyAll(context.Background(), u)
}
n.b.flush()
var got []types.StateUpdate
for len(ch) > 0 {
out := <-ch
got = append(got, out)
}
if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" {
t.Errorf("batcher() unexpected result (-want +got):\n%s", diff)
}
})
}
}

View File

@ -66,10 +66,16 @@ func (h *Headscale) newMapSession(
) *mapSession {
warnf, infof, tracef, errf := logPollFunc(req, node)
// Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we dont block the entire
// notifier.
updateChan := make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
var updateChan chan types.StateUpdate
if req.Stream {
// Use a buffered channel in case a node is not fully ready
// to receive a message to make sure we dont block the entire
// notifier.
updateChan = make(chan types.StateUpdate, h.cfg.Tuning.NodeMapSessionBufferedChanSize)
updateChan <- types.StateUpdate{
Type: types.StateFullUpdate,
}
}
return &mapSession{
h: h,
@ -218,33 +224,26 @@ func (m *mapSession) serve() {
ctx, cancel := context.WithCancel(context.WithValue(m.ctx, nodeNameContextKey, m.node.Hostname))
defer cancel()
// TODO(kradalby): Make this available through a tuning envvar
wait := time.Second
// Add a circuit breaker, if the loop is not interrupted
// inbetween listening for the channels, some updates
// might get stale and stucked in the "changed" map
// defined below.
blockBreaker := time.NewTicker(wait)
// true means changed, false means removed
var changed map[types.NodeID]bool
var patches []*tailcfg.PeerChange
var derp bool
// Set full to true to immediatly send a full mapresponse
full := true
prev := time.Now()
lastMessage := ""
// Loop through updates and continuously send them to the
// client.
for {
// If a full update has been requested or there are patches, then send it immediately
// otherwise wait for the "batching" of changes or patches
if full || patches != nil || (changed != nil && time.Since(prev) > wait) {
// consume channels with update, keep alives or "batch" blocking signals
select {
case <-m.cancelCh:
m.tracef("poll cancelled received")
return
case <-ctx.Done():
m.tracef("poll context done")
return
// Consume all updates sent to node
case update := <-m.ch:
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
var data []byte
var err error
var lastMessage string
// Ensure the node object is updated, for example, there
// might have been a hostinfo update in a sidechannel
@ -256,62 +255,43 @@ func (m *mapSession) serve() {
return
}
// If there are patches _and_ fully changed nodes, filter the
// patches and remove all patches that are present for the full
// changes updates. This allows us to send them as part of the
// PeerChange update, but only for nodes that are not fully changed.
// The fully changed nodes will be updated from the database and
// have all the updates needed.
// This means that the patches left are for nodes that has no
// updates that requires a full update.
// Patches are not suppose to be mixed in, but can be.
//
// From tailcfg docs:
// These are applied after Peers* above, but in practice the
// control server should only send these on their own, without
//
// Currently, there is no effort to merge patch updates, they
// are all sent, and the client will apply them in order.
// TODO(kradalby): Merge Patches for the same IDs to send less
// data and give the client less work.
if patches != nil && changed != nil {
var filteredPatches []*tailcfg.PeerChange
for _, patch := range patches {
if _, ok := changed[types.NodeID(patch.NodeID)]; !ok {
filteredPatches = append(filteredPatches, patch)
}
}
patches = filteredPatches
}
updateType := "full"
// When deciding what update to send, the following is considered,
// Full is a superset of all updates, when a full update is requested,
// send only that and move on, all other updates will be present in
// a full map response.
//
// If a map of changed nodes exists, prefer sending that as it will
// contain all the updates for the node, including patches, as it
// is fetched freshly from the database when building the response.
//
// If there is full changes registered, but we have patches for individual
// nodes, send them.
//
// Finally, if a DERP map is the only request, send that alone.
if full {
switch update.Type {
case types.StateFullUpdate:
m.tracef("Sending Full MapResponse")
data, err = m.mapper.FullMapResponse(m.req, m.node, m.h.ACLPolicy, fmt.Sprintf("from mapSession: %p, stream: %t", m, m.isStreaming()))
} else if changed != nil {
case types.StatePeerChanged:
changed := make(map[types.NodeID]bool, len(update.ChangeNodes))
for _, nodeID := range update.ChangeNodes {
changed[nodeID] = true
}
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage)
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
updateType = "change"
} else if patches != nil {
case types.StatePeerChangedPatch:
m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy)
data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, update.ChangePatches, m.h.ACLPolicy)
updateType = "patch"
} else if derp {
case types.StatePeerRemoved:
changed := make(map[types.NodeID]bool, len(update.Removed))
for _, nodeID := range update.Removed {
changed[nodeID] = false
}
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, update.ChangePatches, m.h.ACLPolicy, lastMessage)
updateType = "remove"
case types.StateSelfUpdate:
lastMessage = update.Message
m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage))
// create the map so an empty (self) update is sent
data, err = m.mapper.PeerChangedResponse(m.req, m.node, make(map[types.NodeID]bool), update.ChangePatches, m.h.ACLPolicy, lastMessage)
updateType = "remove"
case types.StateDERPUpdated:
m.tracef("Sending DERPUpdate MapResponse")
data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap)
updateType = "derp"
@ -348,68 +328,6 @@ func (m *mapSession) serve() {
m.tracef("update sent")
}
// reset
changed = nil
patches = nil
lastMessage = ""
full = false
derp = false
prev = time.Now()
}
// consume channels with update, keep alives or "batch" blocking signals
select {
case <-m.cancelCh:
m.tracef("poll cancelled received")
return
case <-ctx.Done():
m.tracef("poll context done")
return
// Avoid infinite block that would potentially leave
// some updates in the changed map.
case <-blockBreaker.C:
continue
// Consume all updates sent to node
case update := <-m.ch:
m.tracef("received stream update: %s %s", update.Type.String(), update.Message)
mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc()
switch update.Type {
case types.StateFullUpdate:
full = true
case types.StatePeerChanged:
if changed == nil {
changed = make(map[types.NodeID]bool)
}
for _, nodeID := range update.ChangeNodes {
changed[nodeID] = true
}
lastMessage = update.Message
case types.StatePeerChangedPatch:
patches = append(patches, update.ChangePatches...)
case types.StatePeerRemoved:
if changed == nil {
changed = make(map[types.NodeID]bool)
}
for _, nodeID := range update.Removed {
changed[nodeID] = false
}
case types.StateSelfUpdate:
// create the map so an empty (self) update is sent
if changed == nil {
changed = make(map[types.NodeID]bool)
}
lastMessage = update.Message
case types.StateDERPUpdated:
derp = true
}
case <-m.keepAliveTicker.C:
data, err := m.mapper.KeepAliveResponse(m.req, m.node)
if err != nil {

View File

@ -10,6 +10,7 @@ import (
"time"
"tailscale.com/tailcfg"
"tailscale.com/util/ctxkey"
)
const (
@ -183,10 +184,14 @@ func StateUpdateExpire(nodeID NodeID, expiry time.Time) StateUpdate {
}
}
var (
NotifyOriginKey = ctxkey.New("notify.origin", "")
NotifyHostnameKey = ctxkey.New("notify.hostname", "")
)
func NotifyCtx(ctx context.Context, origin, hostname string) context.Context {
ctx2, _ := context.WithTimeout(
context.WithValue(context.WithValue(ctx, "hostname", hostname), "origin", origin),
3*time.Second,
)
ctx2, _ := context.WithTimeout(ctx, 3*time.Second)
ctx2 = NotifyOriginKey.WithValue(ctx2, origin)
ctx2 = NotifyHostnameKey.WithValue(ctx2, hostname)
return ctx2
}

View File

@ -11,6 +11,7 @@ import (
"encoding/pem"
"errors"
"fmt"
"io"
"log"
"math/big"
"net"
@ -396,6 +397,14 @@ func (t *HeadscaleInContainer) Shutdown() error {
)
}
err = t.SaveMetrics("/tmp/control/metrics.txt")
if err != nil {
log.Printf(
"Failed to metrics from control: %s",
err,
)
}
// Send a interrupt signal to the "headscale" process inside the container
// allowing it to shut down gracefully and flush the profile to disk.
// The container will live for a bit longer due to the sleep at the end.
@ -448,6 +457,25 @@ func (t *HeadscaleInContainer) SaveLog(path string) error {
return dockertestutil.SaveLog(t.pool, t.container, path)
}
func (t *HeadscaleInContainer) SaveMetrics(savePath string) error {
resp, err := http.Get(fmt.Sprintf("http://%s:9090/metrics", t.hostname))
if err != nil {
return fmt.Errorf("getting metrics: %w", err)
}
defer resp.Body.Close()
out, err := os.Create(savePath)
if err != nil {
return fmt.Errorf("creating file for metrics: %w", err)
}
defer out.Close()
_, err = io.Copy(out, resp.Body)
if err != nil {
return fmt.Errorf("copy response to file: %w", err)
}
return nil
}
func (t *HeadscaleInContainer) SaveProfile(savePath string) error {
tarFile, err := t.FetchPath("/tmp/profile")
if err != nil {

View File

@ -252,7 +252,7 @@ func TestHASubnetRouterFailover(t *testing.T) {
scenario, err := NewScenario(dockertestMaxWait())
assertNoErrf(t, "failed to create scenario: %s", err)
// defer scenario.Shutdown()
defer scenario.Shutdown()
spec := map[string]int{
user: 3,