diff --git a/hscontrol/app.go b/hscontrol/app.go index 64d40ed1..acc94229 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -225,7 +225,7 @@ func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) { for range ticker.C { var removed []types.NodeID var changed []types.NodeID - if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + if err := h.db.Write(func(tx *gorm.DB) error { removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout) return nil @@ -263,7 +263,7 @@ func (h *Headscale) expireExpiredMachines(intervalMs int64) { var changed bool for range ticker.C { - if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + if err := h.db.Write(func(tx *gorm.DB) error { lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck) return nil @@ -452,6 +452,7 @@ func (h *Headscale) ensureUnixSocketIsAbsent() error { func (h *Headscale) createRouter(grpcMux *grpcRuntime.ServeMux) *mux.Router { router := mux.NewRouter() + router.Use(prometheusMiddleware) router.PathPrefix("/debug/pprof/").Handler(http.DefaultServeMux) router.HandleFunc(ts2021UpgradePath, h.NoiseUpgradeHandler).Methods(http.MethodPost) @@ -508,7 +509,7 @@ func (h *Headscale) Serve() error { // Fetch an initial DERP Map before we start serving h.DERPMap = derp.GetDERPMap(h.cfg.DERP) - h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier.ConnectedMap()) + h.mapper = mapper.NewMapper(h.db, h.cfg, h.DERPMap, h.nodeNotifier) if h.cfg.DERP.ServerEnabled { // When embedded DERP is enabled we always need a STUN server diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 8307d314..0679d72e 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -273,8 +273,6 @@ func (h *Headscale) handleAuthKey( Err(err). Msg("Cannot encode message") http.Error(writer, "Internal server error", http.StatusInternalServerError) - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() return } @@ -294,13 +292,6 @@ func (h *Headscale) handleAuthKey( Str("node", registerRequest.Hostinfo.Hostname). Msg("Failed authentication via AuthKey") - if pak != nil { - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() - } else { - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", "unknown").Inc() - } - return } @@ -404,15 +395,13 @@ func (h *Headscale) handleAuthKey( Caller(). Err(err). Msg("could not register node") - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) return } } - err = h.db.DB.Transaction(func(tx *gorm.DB) error { + err = h.db.Write(func(tx *gorm.DB) error { return db.UsePreAuthKey(tx, pak) }) if err != nil { @@ -420,8 +409,6 @@ func (h *Headscale) handleAuthKey( Caller(). Err(err). Msg("Failed to use pre-auth key") - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) return @@ -440,14 +427,10 @@ func (h *Headscale) handleAuthKey( Str("node", registerRequest.Hostinfo.Hostname). Err(err). Msg("Cannot encode message") - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "error", pak.User.Name). - Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - nodeRegistrations.WithLabelValues("new", util.RegisterMethodAuthKey, "success", pak.User.Name).
- Inc() writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) _, err = writer.Write(respBody) @@ -563,7 +546,7 @@ func (h *Headscale) handleNodeLogOut( } if node.IsEphemeral() { - changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.ConnectedMap()) + changedNodes, err := h.db.DeleteNode(&node, h.nodeNotifier.LikelyConnectedMap()) if err != nil { log.Error(). Err(err). @@ -616,14 +599,10 @@ func (h *Headscale) handleNodeWithValidRegistration( Caller(). Err(err). Msg("Cannot encode message") - nodeRegistrations.WithLabelValues("update", "web", "error", node.User.Name). - Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - nodeRegistrations.WithLabelValues("update", "web", "success", node.User.Name). - Inc() writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) @@ -654,7 +633,7 @@ func (h *Headscale) handleNodeKeyRefresh( Str("node", node.Hostname). Msg("We have the OldNodeKey in the database. This is a key refresh") - err := h.db.DB.Transaction(func(tx *gorm.DB) error { + err := h.db.Write(func(tx *gorm.DB) error { return db.NodeSetNodeKey(tx, &node, registerRequest.NodeKey) }) if err != nil { @@ -737,14 +716,10 @@ func (h *Headscale) handleNodeExpiredOrLoggedOut( Caller(). Err(err). Msg("Cannot encode message") - nodeRegistrations.WithLabelValues("reauth", "web", "error", node.User.Name). - Inc() http.Error(writer, "Internal server error", http.StatusInternalServerError) return } - nodeRegistrations.WithLabelValues("reauth", "web", "success", node.User.Name). - Inc() writer.Header().Set("Content-Type", "application/json; charset=utf-8") writer.WriteHeader(http.StatusOK) diff --git a/hscontrol/auth_noise.go b/hscontrol/auth_noise.go index 323a49b0..6659dfa5 100644 --- a/hscontrol/auth_noise.go +++ b/hscontrol/auth_noise.go @@ -33,7 +33,6 @@ func (ns *noiseServer) NoiseRegistrationHandler( Caller(). Err(err). Msg("Cannot parse RegisterRequest") - nodeRegistrations.WithLabelValues("unknown", "web", "error", "unknown").Inc() http.Error(writer, "Internal error", http.StatusInternalServerError) return diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index 109fd610..91bf0cb3 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -10,6 +10,7 @@ import ( "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" "github.com/patrickmn/go-cache" + "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/tailcfg" @@ -260,9 +261,9 @@ func NodeSetExpiry(tx *gorm.DB, return tx.Model(&types.Node{}).Where("id = ?", nodeID).Update("expiry", expiry).Error } -func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) { +func (hsdb *HSDatabase) DeleteNode(node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) { return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) { - return DeleteNode(tx, node, isConnected) + return DeleteNode(tx, node, isLikelyConnected) }) } @@ -270,9 +271,9 @@ func (hsdb *HSDatabase) DeleteNode(node *types.Node, isConnected types.NodeConne // Caller is responsible for notifying all of change. 
func DeleteNode(tx *gorm.DB, node *types.Node, - isConnected types.NodeConnectedMap, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], ) ([]types.NodeID, error) { - changed, err := deleteNodeRoutes(tx, node, isConnected) + changed, err := deleteNodeRoutes(tx, node, isLikelyConnected) if err != nil { return changed, err } diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index 9ff02287..ce2ada33 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -11,6 +11,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/puzpuzpuz/xsync/v3" "gopkg.in/check.v1" "tailscale.com/tailcfg" "tailscale.com/types/key" @@ -120,7 +121,7 @@ func (s *Suite) TestHardDeleteNode(c *check.C) { } db.DB.Save(&node) - _, err = db.DeleteNode(&node, types.NodeConnectedMap{}) + _, err = db.DeleteNode(&node, xsync.NewMapOf[types.NodeID, bool]()) c.Assert(err, check.IsNil) _, err = db.getNode(user.Name, "testnode3") diff --git a/hscontrol/db/preauth_keys_test.go b/hscontrol/db/preauth_keys_test.go index 2cd59c40..fa9681ac 100644 --- a/hscontrol/db/preauth_keys_test.go +++ b/hscontrol/db/preauth_keys_test.go @@ -147,7 +147,7 @@ func (*Suite) TestEphemeralKeyReusable(c *check.C) { _, err = db.getNode("test7", "testest") c.Assert(err, check.IsNil) - db.DB.Transaction(func(tx *gorm.DB) error { + db.Write(func(tx *gorm.DB) error { DeleteExpiredEphemeralNodes(tx, time.Second*20) return nil }) @@ -181,7 +181,7 @@ func (*Suite) TestEphemeralKeyNotReusable(c *check.C) { _, err = db.getNode("test7", "testest") c.Assert(err, check.IsNil) - db.DB.Transaction(func(tx *gorm.DB) error { + db.Write(func(tx *gorm.DB) error { DeleteExpiredEphemeralNodes(tx, time.Second*20) return nil }) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index bc3f88a5..74b2b4b7 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -8,6 +8,7 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" + "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" "gorm.io/gorm" "tailscale.com/util/set" @@ -126,7 +127,7 @@ func EnableRoute(tx *gorm.DB, id uint64) (*types.StateUpdate, error) { func DisableRoute(tx *gorm.DB, id uint64, - isConnected types.NodeConnectedMap, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], ) ([]types.NodeID, error) { route, err := GetRoute(tx, id) if err != nil { @@ -147,7 +148,7 @@ func DisableRoute(tx *gorm.DB, return nil, err } - update, err = failoverRouteTx(tx, isConnected, route) + update, err = failoverRouteTx(tx, isLikelyConnected, route) if err != nil { return nil, err } @@ -182,17 +183,17 @@ func DisableRoute(tx *gorm.DB, func (hsdb *HSDatabase) DeleteRoute( id uint64, - isConnected types.NodeConnectedMap, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], ) ([]types.NodeID, error) { return Write(hsdb.DB, func(tx *gorm.DB) ([]types.NodeID, error) { - return DeleteRoute(tx, id, isConnected) + return DeleteRoute(tx, id, isLikelyConnected) }) } func DeleteRoute( tx *gorm.DB, id uint64, - isConnected types.NodeConnectedMap, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], ) ([]types.NodeID, error) { route, err := GetRoute(tx, id) if err != nil { @@ -207,7 +208,7 @@ func DeleteRoute( // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 var update []types.NodeID if !route.IsExitRoute() { - update, err = failoverRouteTx(tx, isConnected, route) + update, err = 
failoverRouteTx(tx, isLikelyConnected, route) if err != nil { return nil, nil } @@ -252,7 +253,7 @@ func DeleteRoute( return update, nil } -func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConnectedMap) ([]types.NodeID, error) { +func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isLikelyConnected *xsync.MapOf[types.NodeID, bool]) ([]types.NodeID, error) { routes, err := GetNodeRoutes(tx, node) if err != nil { return nil, fmt.Errorf("getting node routes: %w", err) @@ -266,7 +267,7 @@ func deleteNodeRoutes(tx *gorm.DB, node *types.Node, isConnected types.NodeConne // TODO(kradalby): This is a bit too aggressive, we could probably // figure out which routes needs to be failed over rather than all. - chn, err := failoverRouteTx(tx, isConnected, &routes[i]) + chn, err := failoverRouteTx(tx, isLikelyConnected, &routes[i]) if err != nil { return changed, fmt.Errorf("failing over route after delete: %w", err) } @@ -409,7 +410,7 @@ func SaveNodeRoutes(tx *gorm.DB, node *types.Node) (bool, error) { // If needed, the failover will be attempted. func FailoverNodeRoutesIfNeccessary( tx *gorm.DB, - isConnected types.NodeConnectedMap, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], node *types.Node, ) (*types.StateUpdate, error) { nodeRoutes, err := GetNodeRoutes(tx, node) @@ -430,12 +431,12 @@ nodeRouteLoop: if route.IsPrimary { // if we have a primary route, and the node is connected // nothing needs to be done. - if conn, ok := isConnected[route.Node.ID]; conn && ok { + if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val { continue nodeRouteLoop } // if not, we need to failover the route - failover := failoverRoute(isConnected, &route, routes) + failover := failoverRoute(isLikelyConnected, &route, routes) if failover != nil { err := failover.save(tx) if err != nil { @@ -477,7 +478,7 @@ nodeRouteLoop: // If the given route was not primary, it returns early. 
func failoverRouteTx( tx *gorm.DB, - isConnected types.NodeConnectedMap, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], r *types.Route, ) ([]types.NodeID, error) { if r == nil { @@ -500,7 +501,7 @@ func failoverRouteTx( return nil, fmt.Errorf("getting routes by prefix: %w", err) } - fo := failoverRoute(isConnected, r, routes) + fo := failoverRoute(isLikelyConnected, r, routes) if fo == nil { return nil, nil } @@ -538,7 +539,7 @@ func (f *failover) save(tx *gorm.DB) error { } func failoverRoute( - isConnected types.NodeConnectedMap, + isLikelyConnected *xsync.MapOf[types.NodeID, bool], routeToReplace *types.Route, altRoutes types.Routes, @@ -570,9 +571,11 @@ func failoverRoute( continue } - if isConnected != nil && isConnected[route.Node.ID] { - newPrimary = &altRoutes[idx] - break + if isLikelyConnected != nil { + if val, ok := isLikelyConnected.Load(route.Node.ID); ok && val { + newPrimary = &altRoutes[idx] + break + } } } diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 453a7503..02342ca2 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -10,11 +10,22 @@ import ( "github.com/google/go-cmp/cmp/cmpopts" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/puzpuzpuz/xsync/v3" "gopkg.in/check.v1" "gorm.io/gorm" "tailscale.com/tailcfg" ) +var smap = func(m map[types.NodeID]bool) *xsync.MapOf[types.NodeID, bool] { + s := xsync.NewMapOf[types.NodeID, bool]() + + for k, v := range m { + s.Store(k, v) + } + + return s +} + func (s *Suite) TestGetRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -331,7 +342,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { name string nodes types.Nodes routes types.Routes - isConnected []types.NodeConnectedMap + isConnected []map[types.NodeID]bool want []*types.StateUpdate wantErr bool }{ @@ -346,7 +357,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(1, 1, ipp("10.0.0.0/24"), true, true), r(2, 2, ipp("10.0.0.0/24"), true, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 goes down { 1: false, @@ -384,7 +395,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(1, 1, ipp("10.0.0.0/24"), true, true), r(2, 2, ipp("10.0.0.0/24"), true, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 up recon = noop { 1: true, @@ -428,7 +439,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(2, 2, ipp("10.0.0.0/24"), true, false), r(3, 3, ipp("10.0.0.0/24"), true, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 goes down { 1: false, @@ -486,7 +497,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(2, 2, ipp("10.0.0.0/24"), false, false), r(3, 3, ipp("10.0.0.0/24"), true, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 goes down { 1: false, @@ -516,7 +527,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(2, 2, ipp("10.0.0.0/24"), true, false), r(3, 3, ipp("10.1.0.0/24"), true, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 goes down { 1: false, @@ -539,7 +550,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(2, 2, ipp("10.0.0.0/24"), true, false), r(3, 3, ipp("10.1.0.0/24"), false, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 goes down { 1: false, @@ 
-562,7 +573,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(2, 2, ipp("10.0.0.0/24"), true, false), r(3, 3, ipp("10.1.0.0/24"), true, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 goes down { 1: false, @@ -585,7 +596,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { r(2, 2, ipp("10.0.0.0/24"), true, true), r(3, 3, ipp("10.1.0.0/24"), true, false), }, - isConnected: []types.NodeConnectedMap{ + isConnected: []map[types.NodeID]bool{ // n1 goes down { 1: true, @@ -618,7 +629,7 @@ func TestFailoverNodeRoutesIfNeccessary(t *testing.T) { want := tt.want[step] got, err := Write(db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { - return FailoverNodeRoutesIfNeccessary(tx, isConnected, node) + return FailoverNodeRoutesIfNeccessary(tx, smap(isConnected), node) }) if (err != nil) != tt.wantErr { @@ -640,7 +651,7 @@ func TestFailoverRouteTx(t *testing.T) { name string failingRoute types.Route routes types.Routes - isConnected types.NodeConnectedMap + isConnected map[types.NodeID]bool want []types.NodeID wantErr bool }{ @@ -743,7 +754,7 @@ func TestFailoverRouteTx(t *testing.T) { Enabled: true, }, }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: false, 2: true, }, @@ -841,7 +852,7 @@ func TestFailoverRouteTx(t *testing.T) { Enabled: true, }, }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: true, 2: true, 3: true, @@ -889,7 +900,7 @@ func TestFailoverRouteTx(t *testing.T) { Enabled: true, }, }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: true, 4: false, }, @@ -945,7 +956,7 @@ func TestFailoverRouteTx(t *testing.T) { Enabled: true, }, }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: false, 2: true, 4: false, @@ -1010,7 +1021,7 @@ func TestFailoverRouteTx(t *testing.T) { } got, err := Write(db.DB, func(tx *gorm.DB) ([]types.NodeID, error) { - return failoverRouteTx(tx, tt.isConnected, &tt.failingRoute) + return failoverRouteTx(tx, smap(tt.isConnected), &tt.failingRoute) }) if (err != nil) != tt.wantErr { @@ -1048,7 +1059,7 @@ func TestFailoverRoute(t *testing.T) { name string failingRoute types.Route routes types.Routes - isConnected types.NodeConnectedMap + isConnected map[types.NodeID]bool want *failover }{ { @@ -1085,7 +1096,7 @@ func TestFailoverRoute(t *testing.T) { r(1, 1, ipp("10.0.0.0/24"), true, true), r(2, 2, ipp("10.0.0.0/24"), true, false), }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: false, 2: true, }, @@ -1111,7 +1122,7 @@ func TestFailoverRoute(t *testing.T) { r(2, 2, ipp("10.0.0.0/24"), true, true), r(3, 3, ipp("10.0.0.0/24"), true, false), }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: true, 2: true, 3: true, @@ -1128,7 +1139,7 @@ func TestFailoverRoute(t *testing.T) { r(1, 1, ipp("10.0.0.0/24"), true, true), r(2, 4, ipp("10.0.0.0/24"), true, false), }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: true, 4: false, }, @@ -1142,7 +1153,7 @@ func TestFailoverRoute(t *testing.T) { r(2, 4, ipp("10.0.0.0/24"), true, false), r(3, 2, ipp("10.0.0.0/24"), true, false), }, - isConnected: types.NodeConnectedMap{ + isConnected: map[types.NodeID]bool{ 1: false, 2: true, 4: false, @@ -1172,7 +1183,7 @@ func TestFailoverRoute(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotf := failoverRoute(tt.isConnected, &tt.failingRoute, tt.routes) + gotf := 
failoverRoute(smap(tt.isConnected), &tt.failingRoute, tt.routes) if tt.want == nil && gotf != nil { t.Fatalf("expected nil, got %+v", gotf) diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index a24dcead..41be5e9d 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -145,7 +145,7 @@ func (api headscaleV1APIServer) ExpirePreAuthKey( ctx context.Context, request *v1.ExpirePreAuthKeyRequest, ) (*v1.ExpirePreAuthKeyResponse, error) { - err := api.h.db.DB.Transaction(func(tx *gorm.DB) error { + err := api.h.db.Write(func(tx *gorm.DB) error { preAuthKey, err := db.GetPreAuthKey(tx, request.GetUser(), request.Key) if err != nil { return err @@ -301,7 +301,7 @@ func (api headscaleV1APIServer) DeleteNode( changedNodes, err := api.h.db.DeleteNode( node, - api.h.nodeNotifier.ConnectedMap(), + api.h.nodeNotifier.LikelyConnectedMap(), ) if err != nil { return nil, err @@ -343,7 +343,7 @@ func (api headscaleV1APIServer) ExpireNode( } ctx = types.NotifyCtx(ctx, "cli-expirenode-self", node.Hostname) - api.h.nodeNotifier.NotifyByMachineKey( + api.h.nodeNotifier.NotifyByNodeID( ctx, types.StateUpdate{ Type: types.StateSelfUpdate, @@ -401,7 +401,7 @@ func (api headscaleV1APIServer) ListNodes( ctx context.Context, request *v1.ListNodesRequest, ) (*v1.ListNodesResponse, error) { - isConnected := api.h.nodeNotifier.ConnectedMap() + isLikelyConnected := api.h.nodeNotifier.LikelyConnectedMap() if request.GetUser() != "" { nodes, err := db.Read(api.h.db.DB, func(rx *gorm.DB) (types.Nodes, error) { return db.ListNodesByUser(rx, request.GetUser()) @@ -416,7 +416,9 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. - resp.Online = isConnected[node.ID] + if val, ok := isLikelyConnected.Load(node.ID); ok && val { + resp.Online = true + } response[index] = resp } @@ -439,7 +441,9 @@ func (api headscaleV1APIServer) ListNodes( // Populate the online field based on // currently connected nodes. 
- resp.Online = isConnected[node.ID] + if val, ok := isLikelyConnected.Load(node.ID); ok && val { + resp.Online = true + } validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( node, @@ -528,7 +532,7 @@ func (api headscaleV1APIServer) DisableRoute( request *v1.DisableRouteRequest, ) (*v1.DisableRouteResponse, error) { update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) { - return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.ConnectedMap()) + return db.DisableRoute(tx, request.GetRouteId(), api.h.nodeNotifier.LikelyConnectedMap()) }) if err != nil { return nil, err @@ -568,7 +572,7 @@ func (api headscaleV1APIServer) DeleteRoute( ctx context.Context, request *v1.DeleteRouteRequest, ) (*v1.DeleteRouteResponse, error) { - isConnected := api.h.nodeNotifier.ConnectedMap() + isConnected := api.h.nodeNotifier.LikelyConnectedMap() update, err := db.Write(api.h.db.DB, func(tx *gorm.DB) ([]types.NodeID, error) { return db.DeleteRoute(tx, request.GetRouteId(), isConnected) }) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index fe8af4d3..d4f4392a 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -17,6 +17,7 @@ import ( mapset "github.com/deckarep/golang-set/v2" "github.com/juanfont/headscale/hscontrol/db" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" @@ -51,10 +52,10 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_ type Mapper struct { // Configuration // TODO(kradalby): figure out if this is the format we want this in - db *db.HSDatabase - cfg *types.Config - derpMap *tailcfg.DERPMap - isLikelyConnected types.NodeConnectedMap + db *db.HSDatabase + cfg *types.Config + derpMap *tailcfg.DERPMap + notif *notifier.Notifier uid string created time.Time @@ -70,15 +71,15 @@ func NewMapper( db *db.HSDatabase, cfg *types.Config, derpMap *tailcfg.DERPMap, - isLikelyConnected types.NodeConnectedMap, + notif *notifier.Notifier, ) *Mapper { uid, _ := util.GenerateRandomStringDNSSafe(mapperIDLength) return &Mapper{ - db: db, - cfg: cfg, - derpMap: derpMap, - isLikelyConnected: isLikelyConnected, + db: db, + cfg: cfg, + derpMap: derpMap, + notif: notif, uid: uid, created: time.Now(), @@ -517,7 +518,7 @@ func (m *Mapper) ListPeers(nodeID types.NodeID) (types.Nodes, error) { } for _, peer := range peers { - online := m.isLikelyConnected[peer.ID] + online := m.notif.IsLikelyConnected(peer.ID) peer.IsOnline = &online } diff --git a/hscontrol/metrics.go b/hscontrol/metrics.go index fc56f584..9d802caf 100644 --- a/hscontrol/metrics.go +++ b/hscontrol/metrics.go @@ -1,6 +1,10 @@ package hscontrol import ( + "net/http" + "strconv" + + "github.com/gorilla/mux" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" ) @@ -8,18 +12,94 @@ import ( const prometheusNamespace = "headscale" var ( - // This is a high cardinality metric (user x node), we might want to make this - // configurable/opt-in in the future. 
- nodeRegistrations = promauto.NewCounterVec(prometheus.CounterOpts{ + mapResponseSent = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, - Name: "node_registrations_total", - Help: "The total amount of registered node attempts", - }, []string{"action", "auth", "status", "user"}) - - updateRequestsSentToNode = promauto.NewCounterVec(prometheus.CounterOpts{ + Name: "mapresponse_sent_total", + Help: "total count of mapresponses sent to clients", + }, []string{"status", "type"}) + mapResponseUpdateReceived = promauto.NewCounterVec(prometheus.CounterOpts{ Namespace: prometheusNamespace, - Name: "update_request_sent_to_node_total", - Help: "The number of calls/messages issued on a specific nodes update channel", - }, []string{"user", "node", "status"}) - // TODO(kradalby): This is very debugging, we might want to remove it. + Name: "mapresponse_updates_received_total", + Help: "total count of mapresponse updates received on update channel", + }, []string{"type"}) + mapResponseWriteUpdatesInStream = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "mapresponse_write_updates_in_stream_total", + Help: "total count of writes that occurred in a stream session, pre-68 nodes", + }, []string{"status"}) + mapResponseEndpointUpdates = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "mapresponse_endpoint_updates_total", + Help: "total count of endpoint updates received", + }, []string{"status"}) + mapResponseReadOnly = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "mapresponse_readonly_requests_total", + Help: "total count of readonly requests received", + }, []string{"status"}) + mapResponseSessions = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "mapresponse_current_sessions_total", + Help: "total count of open map response sessions", + }) + mapResponseRejected = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "mapresponse_rejected_new_sessions_total", + Help: "total count of new mapsessions rejected", + }, []string{"reason"}) + httpDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "http_duration_seconds", + Help: "Duration of HTTP requests.", + }, []string{"path"}) + httpCounter = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "http_requests_total", + Help: "Total number of http requests processed", + }, []string{"code", "method", "path"}, + ) ) + +// prometheusMiddleware implements mux.MiddlewareFunc. +func prometheusMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + route := mux.CurrentRoute(r) + path, _ := route.GetPathTemplate() + + // Ignore streaming and noise sessions; + // they have their own router further down.
+ if path == "/ts2021" || path == "/machine/map" || path == "/derp" || path == "/derp/probe" || path == "/bootstrap-dns" { + next.ServeHTTP(w, r) + return + } + + rw := &respWriterProm{ResponseWriter: w} + + timer := prometheus.NewTimer(httpDuration.WithLabelValues(path)) + next.ServeHTTP(rw, r) + timer.ObserveDuration() + httpCounter.WithLabelValues(strconv.Itoa(rw.status), r.Method, path).Inc() + }) +} + +type respWriterProm struct { + http.ResponseWriter + status int + written int64 + wroteHeader bool +} + +func (r *respWriterProm) WriteHeader(code int) { + r.status = code + r.wroteHeader = true + r.ResponseWriter.WriteHeader(code) +} + +func (r *respWriterProm) Write(b []byte) (int, error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + n, err := r.ResponseWriter.Write(b) + r.written += int64(n) + return n, err +} diff --git a/hscontrol/noise.go b/hscontrol/noise.go index 92a89d0f..9ddf2c85 100644 --- a/hscontrol/noise.go +++ b/hscontrol/noise.go @@ -95,6 +95,7 @@ func (h *Headscale) NoiseUpgradeHandler( // The HTTP2 server that exposes this router is created for // a single hijacked connection from /ts2021, using netutil.NewOneConnListener router := mux.NewRouter() + router.Use(prometheusMiddleware) router.HandleFunc("/machine/register", noiseServer.NoiseRegistrationHandler). Methods(http.MethodPost) @@ -267,10 +268,12 @@ func (ns *noiseServer) NoisePollNetMapHandler( defer ns.headscale.mapSessionMu.Unlock() sess.infof("node has an open stream(%p), rejecting new stream", sess) + mapResponseRejected.WithLabelValues("exists").Inc() return } ns.headscale.mapSessions[node.ID] = sess + mapResponseSessions.Inc() ns.headscale.mapSessionMu.Unlock() sess.tracef("releasing lock to check stream") } @@ -283,6 +286,7 @@ func (ns *noiseServer) NoisePollNetMapHandler( defer ns.headscale.mapSessionMu.Unlock() delete(ns.headscale.mapSessions, node.ID) + mapResponseSessions.Dec() sess.tracef("releasing lock to remove stream") } diff --git a/hscontrol/notifier/metrics.go b/hscontrol/notifier/metrics.go new file mode 100644 index 00000000..c461d379 --- /dev/null +++ b/hscontrol/notifier/metrics.go @@ -0,0 +1,27 @@ +package notifier + +import ( + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +const prometheusNamespace = "headscale" + +var ( + notifierWaitForLock = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: prometheusNamespace, + Name: "notifier_wait_for_lock_seconds", + Help: "histogram of time spent waiting for the notifier lock", + Buckets: []float64{0.001, 0.01, 0.1, 0.3, 0.5, 1, 3, 5, 10}, + }, []string{"action"}) + notifierUpdateSent = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: prometheusNamespace, + Name: "notifier_update_sent_total", + Help: "total count of updates sent on nodes channel", + }, []string{"status", "type"}) + notifierNodeUpdateChans = promauto.NewGauge(prometheus.GaugeOpts{ + Namespace: prometheusNamespace, + Name: "notifier_open_channels_total", + Help: "total count of open channels in notifier", + }) +) diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index 6c34af57..4ad58723 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -6,21 +6,23 @@ import ( "slices" "strings" "sync" + "time" "github.com/juanfont/headscale/hscontrol/types" + "github.com/puzpuzpuz/xsync/v3" "github.com/rs/zerolog/log" ) type Notifier struct { l sync.RWMutex nodes map[types.NodeID]chan<- types.StateUpdate - connected
types.NodeConnectedMap + connected *xsync.MapOf[types.NodeID, bool] } func NewNotifier() *Notifier { return &Notifier{ nodes: make(map[types.NodeID]chan<- types.StateUpdate), - connected: make(types.NodeConnectedMap), + connected: xsync.NewMapOf[types.NodeID, bool](), } } @@ -31,16 +33,19 @@ func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) { Uint64("node.id", nodeID.Uint64()). Msg("releasing lock to add node") + start := time.Now() n.l.Lock() defer n.l.Unlock() + notifierWaitForLock.WithLabelValues("add").Observe(time.Since(start).Seconds()) n.nodes[nodeID] = c - n.connected[nodeID] = true + n.connected.Store(nodeID, true) log.Trace(). Uint64("node.id", nodeID.Uint64()). Int("open_chans", len(n.nodes)). Msg("Added new channel") + notifierNodeUpdateChans.Inc() } func (n *Notifier) RemoveNode(nodeID types.NodeID) { @@ -50,20 +55,23 @@ func (n *Notifier) RemoveNode(nodeID types.NodeID) { Uint64("node.id", nodeID.Uint64()). Msg("releasing lock to remove node") + start := time.Now() n.l.Lock() defer n.l.Unlock() + notifierWaitForLock.WithLabelValues("remove").Observe(time.Since(start).Seconds()) if len(n.nodes) == 0 { return } delete(n.nodes, nodeID) - n.connected[nodeID] = false + n.connected.Store(nodeID, false) log.Trace(). Uint64("node.id", nodeID.Uint64()). Int("open_chans", len(n.nodes)). Msg("Removed channel") + notifierNodeUpdateChans.Dec() } // IsConnected reports if a node is connected to headscale and has a @@ -72,17 +80,22 @@ func (n *Notifier) IsConnected(nodeID types.NodeID) bool { n.l.RLock() defer n.l.RUnlock() - return n.connected[nodeID] + if val, ok := n.connected.Load(nodeID); ok { + return val + } + return false } // IsLikelyConnected reports if a node is connected to headscale and has a // poll session open, but doesnt lock, so might be wrong. func (n *Notifier) IsLikelyConnected(nodeID types.NodeID) bool { - return n.connected[nodeID] + if val, ok := n.connected.Load(nodeID); ok { + return val + } + return false } -// TODO(kradalby): This returns a pointer and can be dangerous. -func (n *Notifier) ConnectedMap() types.NodeConnectedMap { +func (n *Notifier) LikelyConnectedMap() *xsync.MapOf[types.NodeID, bool] { return n.connected } @@ -95,45 +108,16 @@ func (n *Notifier) NotifyWithIgnore( update types.StateUpdate, ignoreNodeIDs ...types.NodeID, ) { - log.Trace().Caller().Str("type", update.Type.String()).Msg("acquiring lock to notify") - defer log.Trace(). - Caller(). - Str("type", update.Type.String()). - Msg("releasing lock, finished notifying") - - n.l.RLock() - defer n.l.RUnlock() - - if update.Type == types.StatePeerChangedPatch { - log.Trace().Interface("update", update).Interface("online", n.connected).Msg("PATCH UPDATE SENT") - } - - for nodeID, c := range n.nodes { + for nodeID := range n.nodes { if slices.Contains(ignoreNodeIDs, nodeID) { continue } - select { - case <-ctx.Done(): - log.Error(). - Err(ctx.Err()). - Uint64("node.id", nodeID.Uint64()). - Any("origin", ctx.Value("origin")). - Any("origin-hostname", ctx.Value("hostname")). - Msgf("update not sent, context cancelled") - - return - case c <- update: - log.Trace(). - Uint64("node.id", nodeID.Uint64()). - Any("origin", ctx.Value("origin")). - Any("origin-hostname", ctx.Value("hostname")). 
- Msgf("update successfully sent on chan") - } + n.NotifyByNodeID(ctx, update, nodeID) } } -func (n *Notifier) NotifyByMachineKey( +func (n *Notifier) NotifyByNodeID( ctx context.Context, update types.StateUpdate, nodeID types.NodeID, @@ -144,8 +128,10 @@ func (n *Notifier) NotifyByMachineKey( Str("type", update.Type.String()). Msg("releasing lock, finished notifying") + start := time.Now() n.l.RLock() defer n.l.RUnlock() + notifierWaitForLock.WithLabelValues("notify").Observe(time.Since(start).Seconds()) if c, ok := n.nodes[nodeID]; ok { select { @@ -156,6 +142,7 @@ func (n *Notifier) NotifyByMachineKey( Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update not sent, context cancelled") + notifierUpdateSent.WithLabelValues("cancelled", update.Type.String()).Inc() return case c <- update: @@ -164,6 +151,7 @@ func (n *Notifier) NotifyByMachineKey( Any("origin", ctx.Value("origin")). Any("origin-hostname", ctx.Value("hostname")). Msgf("update successfully sent on chan") + notifierUpdateSent.WithLabelValues("ok", update.Type.String()).Inc() } } } @@ -182,9 +170,10 @@ func (n *Notifier) String() string { b.WriteString("\n") b.WriteString("connected:\n") - for k, v := range n.connected { + n.connected.Range(func(k types.NodeID, v bool) bool { fmt.Fprintf(&b, "\t%d: %t\n", k, v) - } + return true + }) return b.String() } diff --git a/hscontrol/oidc.go b/hscontrol/oidc.go index 0680ce2f..b728a6d0 100644 --- a/hscontrol/oidc.go +++ b/hscontrol/oidc.go @@ -602,7 +602,7 @@ func (h *Headscale) registerNodeForOIDCCallback( return err } - if err := h.db.DB.Transaction(func(tx *gorm.DB) error { + if err := h.db.Write(func(tx *gorm.DB) error { if _, err := db.RegisterNodeFromAuthCallback( // TODO(kradalby): find a better way to use the cache across modules tx, diff --git a/hscontrol/poll.go b/hscontrol/poll.go index c38c65e2..b903f122 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -64,7 +64,7 @@ func (h *Headscale) newMapSession( w http.ResponseWriter, node *types.Node, ) *mapSession { - warnf, tracef, infof, errf := logPollFunc(req, node) + warnf, infof, tracef, errf := logPollFunc(req, node) // Use a buffered channel in case a node is not fully ready // to receive a message to make sure we dont block the entire @@ -196,8 +196,10 @@ func (m *mapSession) serve() { // return err := m.handleSaveNode() if err != nil { + mapResponseWriteUpdatesInStream.WithLabelValues("error").Inc() return } + mapResponseWriteUpdatesInStream.WithLabelValues("ok").Inc() } // Set up the client stream @@ -284,6 +286,7 @@ func (m *mapSession) serve() { patches = filteredPatches } + updateType := "full" // When deciding what update to send, the following is considered, // Full is a superset of all updates, when a full update is requested, // send only that and move on, all other updates will be present in @@ -303,12 +306,15 @@ func (m *mapSession) serve() { } else if changed != nil { m.tracef(fmt.Sprintf("Sending Changed MapResponse: %v", lastMessage)) data, err = m.mapper.PeerChangedResponse(m.req, m.node, changed, patches, m.h.ACLPolicy, lastMessage) + updateType = "change" } else if patches != nil { m.tracef(fmt.Sprintf("Sending Changed Patch MapResponse: %v", lastMessage)) data, err = m.mapper.PeerChangedPatchResponse(m.req, m.node, patches, m.h.ACLPolicy) + updateType = "patch" } else if derp { m.tracef("Sending DERPUpdate MapResponse") data, err = m.mapper.DERPMapResponse(m.req, m.node, m.h.DERPMap) + updateType = "derp" } if err != nil { @@ -324,19 +330,22 @@ func (m 
*mapSession) serve() { startWrite := time.Now() _, err = m.w.Write(data) if err != nil { + mapResponseSent.WithLabelValues("error", updateType).Inc() m.errf(err, "Could not write the map response, for mapSession: %p", m) return } err = rc.Flush() if err != nil { + mapResponseSent.WithLabelValues("error", updateType).Inc() m.errf(err, "flushing the map response to client, for mapSession: %p", m) return } log.Trace().Str("node", m.node.Hostname).TimeDiff("timeSpent", time.Now(), startWrite).Str("mkey", m.node.MachineKey.String()).Msg("finished writing mapresp to node") - m.infof("update sent") + mapResponseSent.WithLabelValues("ok", updateType).Inc() + m.tracef("update sent") } // reset @@ -364,7 +373,8 @@ func (m *mapSession) serve() { // Consume all updates sent to node case update := <-m.ch: - m.tracef("received stream update: %d %s", update.Type, update.Message) + m.tracef("received stream update: %s %s", update.Type.String(), update.Message) + mapResponseUpdateReceived.WithLabelValues(update.Type.String()).Inc() switch update.Type { case types.StateFullUpdate: @@ -404,27 +414,30 @@ func (m *mapSession) serve() { data, err := m.mapper.KeepAliveResponse(m.req, m.node) if err != nil { m.errf(err, "Error generating the keep alive msg") - + mapResponseSent.WithLabelValues("error", "keepalive").Inc() return } _, err = m.w.Write(data) if err != nil { m.errf(err, "Cannot write keep alive message") - + mapResponseSent.WithLabelValues("error", "keepalive").Inc() return } err = rc.Flush() if err != nil { m.errf(err, "flushing keep alive to client, for mapSession: %p", m) + mapResponseSent.WithLabelValues("error", "keepalive").Inc() return } + + mapResponseSent.WithLabelValues("ok", "keepalive").Inc() } } } func (m *mapSession) pollFailoverRoutes(where string, node *types.Node) { update, err := db.Write(m.h.db.DB, func(tx *gorm.DB) (*types.StateUpdate, error) { - return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.ConnectedMap(), node) + return db.FailoverNodeRoutesIfNeccessary(tx, m.h.nodeNotifier.LikelyConnectedMap(), node) }) if err != nil { m.errf(err, fmt.Sprintf("failed to ensure failover routes, %s", where)) @@ -454,7 +467,7 @@ func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { node.LastSeen = &now change.LastSeen = &now - err := h.db.DB.Transaction(func(tx *gorm.DB) error { + err := h.db.Write(func(tx *gorm.DB) error { return db.SetLastSeen(tx, node.ID, *node.LastSeen) }) if err != nil { @@ -501,6 +514,7 @@ func (m *mapSession) handleEndpointUpdate() { // If there is no changes and nothing to save, // return early. if peerChangeEmpty(change) && !sendUpdate { + mapResponseEndpointUpdates.WithLabelValues("noop").Inc() return } @@ -518,6 +532,7 @@ func (m *mapSession) handleEndpointUpdate() { if err != nil { m.errf(err, "Error processing node routes") http.Error(m.w, "", http.StatusInternalServerError) + mapResponseEndpointUpdates.WithLabelValues("error").Inc() return } @@ -527,6 +542,7 @@ func (m *mapSession) handleEndpointUpdate() { err := m.h.db.EnableAutoApprovedRoutes(m.h.ACLPolicy, m.node) if err != nil { m.errf(err, "Error running auto approved routes") + mapResponseEndpointUpdates.WithLabelValues("error").Inc() } } @@ -534,19 +550,19 @@ func (m *mapSession) handleEndpointUpdate() { // has an updated packetfilter allowing the new route // if it is defined in the ACL. 
ctx := types.NotifyCtx(context.Background(), "poll-nodeupdate-self-hostinfochange", m.node.Hostname) - m.h.nodeNotifier.NotifyByMachineKey( + m.h.nodeNotifier.NotifyByNodeID( ctx, types.StateUpdate{ Type: types.StateSelfUpdate, ChangeNodes: []types.NodeID{m.node.ID}, }, m.node.ID) - } if err := m.h.db.DB.Save(m.node).Error; err != nil { m.errf(err, "Failed to persist/update node in the database") http.Error(m.w, "", http.StatusInternalServerError) + mapResponseEndpointUpdates.WithLabelValues("error").Inc() return } @@ -562,6 +578,7 @@ func (m *mapSession) handleEndpointUpdate() { m.node.ID) m.w.WriteHeader(http.StatusOK) + mapResponseEndpointUpdates.WithLabelValues("ok").Inc() return } @@ -639,7 +656,7 @@ func (m *mapSession) handleReadOnlyRequest() { if err != nil { m.errf(err, "Failed to create MapResponse") http.Error(m.w, "", http.StatusInternalServerError) - + mapResponseReadOnly.WithLabelValues("error").Inc() return } @@ -648,9 +665,12 @@ func (m *mapSession) handleReadOnlyRequest() { _, err = m.w.Write(mapResp) if err != nil { m.errf(err, "Failed to write response") + mapResponseReadOnly.WithLabelValues("error").Inc() + return } m.w.WriteHeader(http.StatusOK) + mapResponseReadOnly.WithLabelValues("ok").Inc() return } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index 0e30bd9e..7f285924 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -28,7 +28,8 @@ var ( ) type NodeID uint64 -type NodeConnectedMap map[NodeID]bool + +// type NodeConnectedMap *xsync.MapOf[NodeID, bool] func (id NodeID) StableID() tailcfg.StableNodeID { return tailcfg.StableNodeID(strconv.FormatUint(uint64(id), util.Base10)) diff --git a/integration/acl_test.go b/integration/acl_test.go index 517e2dfb..9d763965 100644 --- a/integration/acl_test.go +++ b/integration/acl_test.go @@ -51,7 +51,7 @@ func aclScenario( clientsPerUser int, ) *Scenario { t.Helper() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) spec := map[string]int{ @@ -264,7 +264,7 @@ func TestACLHostsInNetMapTable(t *testing.T) { for name, testCase := range tests { t.Run(name, func(t *testing.T) { - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) spec := testCase.users diff --git a/integration/auth_oidc_test.go b/integration/auth_oidc_test.go index 347dbcc1..d24bf452 100644 --- a/integration/auth_oidc_test.go +++ b/integration/auth_oidc_test.go @@ -42,7 +42,7 @@ func TestOIDCAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) t.Parallel() - baseScenario, err := NewScenario() + baseScenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) scenario := AuthOIDCScenario{ @@ -100,7 +100,7 @@ func TestOIDCExpireNodesBasedOnTokenExpiry(t *testing.T) { shortAccessTTL := 5 * time.Minute - baseScenario, err := NewScenario() + baseScenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) baseScenario.pool.MaxWait = 5 * time.Minute diff --git a/integration/auth_web_flow_test.go b/integration/auth_web_flow_test.go index 6d981bc1..8e121ca0 100644 --- a/integration/auth_web_flow_test.go +++ b/integration/auth_web_flow_test.go @@ -26,7 +26,7 @@ func TestAuthWebFlowAuthenticationPingAll(t *testing.T) { IntegrationSkip(t) t.Parallel() - baseScenario, err := NewScenario() + baseScenario, err := NewScenario(dockertestMaxWait()) if err != nil { t.Fatalf("failed to create scenario: %s", err) } @@ -67,7 +67,7 @@ func TestAuthWebFlowLogoutAndRelogin(t *testing.T) { IntegrationSkip(t) t.Parallel() - 
baseScenario, err := NewScenario() + baseScenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) scenario := AuthWebFlowScenario{ diff --git a/integration/cli_test.go b/integration/cli_test.go index af7b073b..24e3b19b 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -32,7 +32,7 @@ func TestUserCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -112,7 +112,7 @@ func TestPreAuthKeyCommand(t *testing.T) { user := "preauthkeyspace" count := 3 - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -254,7 +254,7 @@ func TestPreAuthKeyCommandWithoutExpiry(t *testing.T) { user := "pre-auth-key-without-exp-user" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -317,7 +317,7 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { user := "pre-auth-key-reus-ephm-user" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -394,7 +394,7 @@ func TestApiKeyCommand(t *testing.T) { count := 5 - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -562,7 +562,7 @@ func TestNodeTagCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -695,7 +695,7 @@ func TestNodeAdvertiseTagNoACLCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -745,7 +745,7 @@ func TestNodeAdvertiseTagWithACLCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -808,7 +808,7 @@ func TestNodeCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -1049,7 +1049,7 @@ func TestNodeExpireCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -1176,7 +1176,7 @@ func TestNodeRenameCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -1343,7 +1343,7 @@ func TestNodeMoveCommand(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index b6a62e5f..39a9acca 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -23,7 +23,7 @@ func TestDERPServerScenario(t *testing.T) { IntegrationSkip(t) // t.Parallel() - baseScenario, err := NewScenario() + baseScenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) scenario := EmbeddedDERPServerScenario{ diff --git a/integration/general_test.go 
b/integration/general_test.go index ffd209d8..89e0d342 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -23,7 +23,7 @@ func TestPingAllByIP(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -67,7 +67,7 @@ func TestPingAllByIPPublicDERP(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -105,7 +105,7 @@ func TestAuthKeyLogoutAndRelogin(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -216,7 +216,7 @@ func TestEphemeral(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -299,7 +299,7 @@ func TestPingAllByHostname(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -348,7 +348,7 @@ func TestTaildrop(t *testing.T) { return err } - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -509,7 +509,7 @@ func TestResolveMagicDNS(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -577,7 +577,7 @@ func TestExpireNode(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -703,7 +703,7 @@ func TestNodeOnlineStatus(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -818,7 +818,7 @@ func TestPingAllByIPManyUpDown(t *testing.T) { IntegrationSkip(t) t.Parallel() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() diff --git a/integration/hsic/hsic.go b/integration/hsic/hsic.go index de4ec41f..0483213b 100644 --- a/integration/hsic/hsic.go +++ b/integration/hsic/hsic.go @@ -18,6 +18,7 @@ import ( "net/url" "os" "path" + "strconv" "strings" "time" @@ -201,6 +202,14 @@ func WithEmbeddedDERPServerOnly() Option { } } +// WithTuning allows changing the tuning settings easily. +func WithTuning(batchTimeout time.Duration, mapSessionChanSize int) Option { + return func(hsic *HeadscaleInContainer) { + hsic.env["HEADSCALE_TUNING_BATCH_CHANGE_DELAY"] = batchTimeout.String() + hsic.env["HEADSCALE_TUNING_NODE_MAPSESSION_BUFFERED_CHAN_SIZE"] = strconv.Itoa(mapSessionChanSize) + } +} + // New returns a new HeadscaleInContainer instance. 
func New( pool *dockertest.Pool, diff --git a/integration/route_test.go b/integration/route_test.go index 150dbd27..15ea22b1 100644 --- a/integration/route_test.go +++ b/integration/route_test.go @@ -28,7 +28,7 @@ func TestEnablingRoutes(t *testing.T) { user := "enable-routing" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErrf(t, "failed to create scenario: %s", err) defer scenario.Shutdown() @@ -250,7 +250,7 @@ func TestHASubnetRouterFailover(t *testing.T) { user := "enable-routing" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErrf(t, "failed to create scenario: %s", err) // defer scenario.Shutdown() @@ -822,7 +822,7 @@ func TestEnableDisableAutoApprovedRoute(t *testing.T) { user := "enable-disable-routing" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErrf(t, "failed to create scenario: %s", err) defer scenario.Shutdown() @@ -966,7 +966,7 @@ func TestSubnetRouteACL(t *testing.T) { user := "subnet-route-acl" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErrf(t, "failed to create scenario: %s", err) defer scenario.Shutdown() diff --git a/integration/scenario.go b/integration/scenario.go index 0ba44e7d..9444d882 100644 --- a/integration/scenario.go +++ b/integration/scenario.go @@ -8,6 +8,7 @@ import ( "os" "sort" "sync" + "time" v1 "github.com/juanfont/headscale/gen/go/headscale/v1" "github.com/juanfont/headscale/hscontrol/util" @@ -141,7 +142,7 @@ type Scenario struct { // NewScenario creates a test Scenario which can be used to bootstraps a ControlServer with // a set of Users and TailscaleClients. -func NewScenario() (*Scenario, error) { +func NewScenario(maxWait time.Duration) (*Scenario, error) { hash, err := util.GenerateRandomStringDNSSafe(scenarioHashLength) if err != nil { return nil, err @@ -152,7 +153,7 @@ func NewScenario() (*Scenario, error) { return nil, fmt.Errorf("could not connect to docker: %w", err) } - pool.MaxWait = dockertestMaxWait() + pool.MaxWait = maxWait networkName := fmt.Sprintf("hs-%s", hash) if overrideNetworkName := os.Getenv("HEADSCALE_TEST_NETWORK_NAME"); overrideNetworkName != "" { diff --git a/integration/scenario_test.go b/integration/scenario_test.go index cc9810a4..ea941ed7 100644 --- a/integration/scenario_test.go +++ b/integration/scenario_test.go @@ -33,7 +33,7 @@ func TestHeadscale(t *testing.T) { user := "test-space" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -78,7 +78,7 @@ func TestCreateTailscale(t *testing.T) { user := "only-create-containers" - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() @@ -114,7 +114,7 @@ func TestTailscaleNodesJoiningHeadcale(t *testing.T) { count := 1 - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) defer scenario.Shutdown() diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 587190e4..6d053b0d 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -44,7 +44,7 @@ var retry = func(times int, sleepInterval time.Duration, func sshScenario(t *testing.T, policy *policy.ACLPolicy, clientsPerUser int) *Scenario { t.Helper() - scenario, err := NewScenario() + scenario, err := NewScenario(dockertestMaxWait()) assertNoErr(t, err) spec := map[string]int{
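Note on the core type change above: types.NodeConnectedMap, a plain map[types.NodeID]bool that was only safe while holding the notifier's lock, is replaced throughout by *xsync.MapOf[types.NodeID, bool], which supports concurrent readers and writers. A minimal, self-contained sketch of the access pattern the patch standardizes on (the node IDs are illustrative only):

package main

import (
	"fmt"

	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/puzpuzpuz/xsync/v3"
)

func main() {
	connected := xsync.NewMapOf[types.NodeID, bool]()

	// AddNode/RemoveNode store true/false rather than deleting entries.
	connected.Store(types.NodeID(1), true)
	connected.Store(types.NodeID(2), false)

	// A missing key must read as "not connected", which is why call sites
	// now check both the ok flag and the value instead of indexing a map.
	if val, ok := connected.Load(types.NodeID(1)); ok && val {
		fmt.Println("node 1 is likely connected")
	}

	// Range iterates without an external lock, as Notifier.String does.
	connected.Range(func(id types.NodeID, online bool) bool {
		fmt.Printf("%d: %t\n", id, online)
		return true
	})
}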
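The recurring rewrites from h.db.DB.Transaction to h.db.Write and db.Write funnel all writes through the database layer's own transaction helpers instead of reaching into the embedded gorm handle. Their implementations are not part of this diff; the following is only a plausible shape for the generic helper, sketched to make the call sites above easier to read:

package db

import "gorm.io/gorm"

// Write is a sketch of the generic transaction helper assumed by the call
// sites above; the real implementation lives in hscontrol/db and may differ.
func Write[T any](db *gorm.DB, fn func(tx *gorm.DB) (T, error)) (T, error) {
	var result T

	// Run fn inside a transaction; gorm rolls back when fn returns an error
	// and commits otherwise.
	err := db.Transaction(func(tx *gorm.DB) error {
		var err error
		result, err = fn(tx)

		return err
	})

	return result, err
}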
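On the metrics side, prometheusMiddleware labels http_duration_seconds and http_requests_total with the mux path template (for example "/machine/register") rather than the raw URL, keeping label cardinality bounded; the same concern, flagged in the removed "high cardinality metric (user x node)" comment, is why the per-user node_registrations_total counter is dropped. A minimal usage sketch, assuming the middleware from this patch is in scope; newInstrumentedRouter and the /metrics and /health routes are hypothetical:

package hscontrol

import (
	"net/http"

	"github.com/gorilla/mux"
	"github.com/prometheus/client_golang/prometheus/promhttp"
)

// newInstrumentedRouter is a hypothetical helper showing the middleware in use.
func newInstrumentedRouter() *mux.Router {
	router := mux.NewRouter()
	router.Use(prometheusMiddleware) // times and counts every matched route

	// Expose the collected metrics; hit any route and watch
	// headscale_http_requests_total increase.
	router.Handle("/metrics", promhttp.Handler())
	router.HandleFunc("/health", func(w http.ResponseWriter, r *http.Request) {
		w.WriteHeader(http.StatusOK)
	})

	return router
}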