From f65f4eca35ba0987baba8628ac8ce7dc51982b65 Mon Sep 17 00:00:00 2001 From: Kristoffer Dalby Date: Sat, 9 Dec 2023 18:09:24 +0100 Subject: [PATCH] ensure online status and route changes are propagated (#1564) --- ...gration-v2-TestHASubnetRouterFailover.yaml | 67 ++ ...ation-v2-TestNodeOnlineLastSeenStatus.yaml | 67 ++ .gitignore | 1 + CHANGELOG.md | 2 + cmd/build-docker-img/main.go | 47 -- cmd/headscale/cli/nodes.go | 2 +- flake.nix | 2 +- go.mod | 2 +- hscontrol/app.go | 31 +- hscontrol/auth.go | 111 ++- hscontrol/db/db.go | 6 +- hscontrol/db/node.go | 188 +++-- hscontrol/db/node_test.go | 3 +- hscontrol/db/routes.go | 495 +++++++---- hscontrol/db/routes_test.go | 517 ++++++++---- hscontrol/derp/server/derp_server.go | 9 +- hscontrol/grpcv1.go | 29 +- hscontrol/mapper/mapper.go | 137 ++- hscontrol/mapper/mapper_test.go | 7 - hscontrol/mapper/tail.go | 16 +- hscontrol/mapper/tail_test.go | 2 - hscontrol/notifier/notifier.go | 28 + hscontrol/policy/acls_test.go | 31 +- hscontrol/poll.go | 367 +++++--- hscontrol/types/common.go | 50 +- hscontrol/types/node.go | 136 ++- hscontrol/types/node_test.go | 227 +++++ hscontrol/types/routes.go | 34 +- hscontrol/types/routes_test.go | 94 +++ hscontrol/util/test.go | 32 + integration/cli_test.go | 144 +--- integration/embedded_derp_test.go | 3 + integration/general_test.go | 224 ++++- integration/route_test.go | 780 ++++++++++++++++++ integration/run.sh | 2 +- integration/scenario.go | 19 +- integration/ssh_test.go | 10 +- integration/tailscale.go | 3 + integration/tsic/tsic.go | 91 +- integration/utils.go | 11 +- 40 files changed, 3170 insertions(+), 857 deletions(-) create mode 100644 .github/workflows/test-integration-v2-TestHASubnetRouterFailover.yaml create mode 100644 .github/workflows/test-integration-v2-TestNodeOnlineLastSeenStatus.yaml delete mode 100644 cmd/build-docker-img/main.go create mode 100644 hscontrol/types/routes_test.go create mode 100644 hscontrol/util/test.go create mode 100644 integration/route_test.go diff --git a/.github/workflows/test-integration-v2-TestHASubnetRouterFailover.yaml b/.github/workflows/test-integration-v2-TestHASubnetRouterFailover.yaml new file mode 100644 index 00000000..4ffe4640 --- /dev/null +++ b/.github/workflows/test-integration-v2-TestHASubnetRouterFailover.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestHASubnetRouterFailover + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestHASubnetRouterFailover: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestHASubnetRouterFailover + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock 
\ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... \ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestHASubnetRouterFailover$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/.github/workflows/test-integration-v2-TestNodeOnlineLastSeenStatus.yaml b/.github/workflows/test-integration-v2-TestNodeOnlineLastSeenStatus.yaml new file mode 100644 index 00000000..e3a30f83 --- /dev/null +++ b/.github/workflows/test-integration-v2-TestNodeOnlineLastSeenStatus.yaml @@ -0,0 +1,67 @@ +# DO NOT EDIT, generated with cmd/gh-action-integration-generator/main.go +# To regenerate, run "go generate" in cmd/gh-action-integration-generator/ + +name: Integration Test v2 - TestNodeOnlineLastSeenStatus + +on: [pull_request] + +concurrency: + group: ${{ github.workflow }}-$${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + TestNodeOnlineLastSeenStatus: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + with: + fetch-depth: 2 + + - uses: DeterminateSystems/nix-installer-action@main + - uses: DeterminateSystems/magic-nix-cache-action@main + - uses: satackey/action-docker-layer-caching@main + continue-on-error: true + + - name: Get changed files + id: changed-files + uses: tj-actions/changed-files@v34 + with: + files: | + *.nix + go.* + **/*.go + integration_test/ + config-example.yaml + + - name: Run TestNodeOnlineLastSeenStatus + uses: Wandalen/wretry.action@master + if: steps.changed-files.outputs.any_changed == 'true' + with: + attempt_limit: 5 + command: | + nix develop --command -- docker run \ + --tty --rm \ + --volume ~/.cache/hs-integration-go:/go \ + --name headscale-test-suite \ + --volume $PWD:$PWD -w $PWD/integration \ + --volume /var/run/docker.sock:/var/run/docker.sock \ + --volume $PWD/control_logs:/tmp/control \ + golang:1 \ + go run gotest.tools/gotestsum@latest -- ./... \ + -failfast \ + -timeout 120m \ + -parallel 1 \ + -run "^TestNodeOnlineLastSeenStatus$" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: logs + path: "control_logs/*.log" + + - uses: actions/upload-artifact@v3 + if: always() && steps.changed-files.outputs.any_changed == 'true' + with: + name: pprof + path: "control_logs/*.pprof.tar" diff --git a/.gitignore b/.gitignore index 3b85ecbb..f6e506bc 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ ignored/ tailscale/ +.vscode/ # Binaries for programs and plugins *.exe diff --git a/CHANGELOG.md b/CHANGELOG.md index 5157921d..47652cba 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ after improving the test harness as part of adopting [#1460](https://github.com/ - API: Machine is now Node [#1553](https://github.com/juanfont/headscale/pull/1553) - Remove support for older Tailscale clients [#1611](https://github.com/juanfont/headscale/pull/1611) - The latest supported client is 1.32 +- Headscale checks that _at least_ one DERP is defined at start [#1564](https://github.com/juanfont/headscale/pull/1564) + - If no DERP is configured, the server will fail to start, this can be because it cannot load the DERPMap from file or url. 
### Changes diff --git a/cmd/build-docker-img/main.go b/cmd/build-docker-img/main.go deleted file mode 100644 index e162aa64..00000000 --- a/cmd/build-docker-img/main.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "log" - - "github.com/juanfont/headscale/integration" - "github.com/juanfont/headscale/integration/tsic" - "github.com/ory/dockertest/v3" -) - -func main() { - log.Printf("creating docker pool") - pool, err := dockertest.NewPool("") - if err != nil { - log.Fatalf("could not connect to docker: %s", err) - } - - log.Printf("creating docker network") - network, err := pool.CreateNetwork("docker-integration-net") - if err != nil { - log.Fatalf("failed to create or get network: %s", err) - } - - for _, version := range integration.AllVersions { - log.Printf("creating container image for Tailscale (%s)", version) - - tsClient, err := tsic.New( - pool, - version, - network, - ) - if err != nil { - log.Fatalf("failed to create tailscale node: %s", err) - } - - err = tsClient.Shutdown() - if err != nil { - log.Fatalf("failed to shut down container: %s", err) - } - } - - network.Close() - err = pool.RemoveNetwork(network) - if err != nil { - log.Fatalf("failed to remove network: %s", err) - } -} diff --git a/cmd/headscale/cli/nodes.go b/cmd/headscale/cli/nodes.go index b1632d6c..ac996245 100644 --- a/cmd/headscale/cli/nodes.go +++ b/cmd/headscale/cli/nodes.go @@ -493,7 +493,7 @@ func nodesToPtables( "Ephemeral", "Last seen", "Expiration", - "Online", + "Connected", "Expired", } if showTags { diff --git a/flake.nix b/flake.nix index 44968325..44cf6fe8 100644 --- a/flake.nix +++ b/flake.nix @@ -31,7 +31,7 @@ # When updating go.mod or go.sum, a new sha will need to be calculated, # update this if you have a mismatch after doing a change to thos files. 
- vendorHash = "sha256-2ci6m1rKI3QdwbkqaGQlf0R+w4PhD0lkrLAu6wKj1LE="; + vendorHash = "sha256-7yqJbF0GkKa3wjiGWJ8BZSJyckrpwmCiX77/aoPGmRc="; ldflags = ["-s" "-w" "-X github.com/juanfont/headscale/cmd/headscale/cli.Version=v${version}"]; }; diff --git a/go.mod b/go.mod index 2d06cc99..dc6f3dac 100644 --- a/go.mod +++ b/go.mod @@ -35,6 +35,7 @@ require ( github.com/tcnksm/go-latest v0.0.0-20170313132115-e3007ae9052e go4.org/netipx v0.0.0-20230824141953-6213f710f925 golang.org/x/crypto v0.16.0 + golang.org/x/exp v0.0.0-20231127185646-65229373498e golang.org/x/net v0.19.0 golang.org/x/oauth2 v0.15.0 golang.org/x/sync v0.5.0 @@ -146,7 +147,6 @@ require ( github.com/xo/terminfo v0.0.0-20220910002029-abceb7e1c41e // indirect go.uber.org/multierr v1.11.0 // indirect go4.org/mem v0.0.0-20220726221520-4f986261bf13 // indirect - golang.org/x/exp v0.0.0-20231127185646-65229373498e // indirect golang.org/x/mod v0.14.0 // indirect golang.org/x/sys v0.15.0 // indirect golang.org/x/term v0.15.0 // indirect diff --git a/hscontrol/app.go b/hscontrol/app.go index 01ae3a78..9a879c82 100644 --- a/hscontrol/app.go +++ b/hscontrol/app.go @@ -59,6 +59,7 @@ var ( errUnsupportedLetsEncryptChallengeType = errors.New( "unknown value for Lets Encrypt challenge type", ) + errEmptyInitialDERPMap = errors.New("initial DERPMap is empty, Headscale requries at least one entry") ) const ( @@ -193,7 +194,10 @@ func NewHeadscale(cfg *types.Config) (*Headscale, error) { } if derpServerKey.Equal(*noisePrivateKey) { - return nil, fmt.Errorf("DERP server private key and noise private key are the same: %w", err) + return nil, fmt.Errorf( + "DERP server private key and noise private key are the same: %w", + err, + ) } embeddedDERPServer, err := derpServer.NewDERPServer( @@ -259,20 +263,13 @@ func (h *Headscale) scheduledDERPMapUpdateWorker(cancelChan <-chan struct{}) { h.DERPMap.Regions[region.RegionID] = ®ion } - h.nodeNotifier.NotifyAll(types.StateUpdate{ + stateUpdate := types.StateUpdate{ Type: types.StateDERPUpdated, - DERPMap: *h.DERPMap, - }) - } - } -} - -func (h *Headscale) failoverSubnetRoutes(milliSeconds int64) { - ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond) - for range ticker.C { - err := h.db.HandlePrimarySubnetFailover() - if err != nil { - log.Error().Err(err).Msg("failed to handle primary subnet failover") + DERPMap: h.DERPMap, + } + if stateUpdate.Valid() { + h.nodeNotifier.NotifyAll(stateUpdate) + } } } } @@ -505,13 +502,15 @@ func (h *Headscale) Serve() error { go h.scheduledDERPMapUpdateWorker(derpMapCancelChannel) } + if len(h.DERPMap.Regions) == 0 { + return errEmptyInitialDERPMap + } + // TODO(kradalby): These should have cancel channels and be cleaned // up on shutdown. go h.expireEphemeralNodes(updateInterval) go h.expireExpiredMachines(updateInterval) - go h.failoverSubnetRoutes(updateInterval) - if zl.GlobalLevel() == zl.TraceLevel { zerolog.RespLog = true } else { diff --git a/hscontrol/auth.go b/hscontrol/auth.go index 5022f65a..4fe5a16b 100644 --- a/hscontrol/auth.go +++ b/hscontrol/auth.go @@ -16,6 +16,46 @@ import ( "tailscale.com/types/key" ) +func logAuthFunc( + registerRequest tailcfg.RegisterRequest, + machineKey key.MachinePublic, +) (func(string), func(string), func(error, string)) { + return func(msg string) { + log.Info(). + Caller(). + Str("machine_key", machineKey.ShortString()). + Str("node_key", registerRequest.NodeKey.ShortString()). + Str("node_key_old", registerRequest.OldNodeKey.ShortString()). + Str("node", registerRequest.Hostinfo.Hostname). 
+ Str("followup", registerRequest.Followup). + Time("expiry", registerRequest.Expiry). + Msg(msg) + }, + func(msg string) { + log.Trace(). + Caller(). + Str("machine_key", machineKey.ShortString()). + Str("node_key", registerRequest.NodeKey.ShortString()). + Str("node_key_old", registerRequest.OldNodeKey.ShortString()). + Str("node", registerRequest.Hostinfo.Hostname). + Str("followup", registerRequest.Followup). + Time("expiry", registerRequest.Expiry). + Msg(msg) + }, + func(err error, msg string) { + log.Error(). + Caller(). + Str("machine_key", machineKey.ShortString()). + Str("node_key", registerRequest.NodeKey.ShortString()). + Str("node_key_old", registerRequest.OldNodeKey.ShortString()). + Str("node", registerRequest.Hostinfo.Hostname). + Str("followup", registerRequest.Followup). + Time("expiry", registerRequest.Expiry). + Err(err). + Msg(msg) + } +} + // handleRegister is the logic for registering a client. func (h *Headscale) handleRegister( writer http.ResponseWriter, @@ -23,8 +63,11 @@ func (h *Headscale) handleRegister( registerRequest tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { + logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey) now := time.Now().UTC() + logTrace("handleRegister called, looking up machine in DB") node, err := h.db.GetNodeByAnyKey(machineKey, registerRequest.NodeKey, registerRequest.OldNodeKey) + logTrace("handleRegister database lookup has returned") if errors.Is(err, gorm.ErrRecordNotFound) { // If the node has AuthKey set, handle registration via PreAuthKeys if registerRequest.Auth.AuthKey != "" { @@ -42,15 +85,9 @@ func (h *Headscale) handleRegister( // is that the client will hammer headscale with requests until it gets a // successful RegisterResponse. if registerRequest.Followup != "" { + logTrace("register request is a followup") if _, ok := h.registrationCache.Get(machineKey.String()); ok { - log.Debug(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("follow_up", registerRequest.Followup). - Msg("Node is waiting for interactive login") + logTrace("Node is waiting for interactive login") select { case <-req.Context().Done(): @@ -63,26 +100,14 @@ func (h *Headscale) handleRegister( } } - log.Info(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Str("machine_key", machineKey.ShortString()). - Str("node_key", registerRequest.NodeKey.ShortString()). - Str("node_key_old", registerRequest.OldNodeKey.ShortString()). - Str("follow_up", registerRequest.Followup). - Msg("New node not yet in the database") + logInfo("Node not found in database, creating new") givenName, err := h.db.GenerateGivenName( machineKey, registerRequest.Hostinfo.Hostname, ) if err != nil { - log.Error(). - Caller(). - Str("func", "RegistrationHandler"). - Str("hostinfo.name", registerRequest.Hostinfo.Hostname). - Err(err). - Msg("Failed to generate given name for node") + logErr(err, "Failed to generate given name for node") return } @@ -101,11 +126,7 @@ func (h *Headscale) handleRegister( } if !registerRequest.Expiry.IsZero() { - log.Trace(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Time("expiry", registerRequest.Expiry). 
- Msg("Non-zero expiry time requested") + logTrace("Non-zero expiry time requested") newNode.Expiry = ®isterRequest.Expiry } @@ -419,13 +440,12 @@ func (h *Headscale) handleNewNode( registerRequest tailcfg.RegisterRequest, machineKey key.MachinePublic, ) { + logInfo, logTrace, logErr := logAuthFunc(registerRequest, machineKey) + resp := tailcfg.RegisterResponse{} // The node registration is new, redirect the client to the registration URL - log.Debug(). - Caller(). - Str("node", registerRequest.Hostinfo.Hostname). - Msg("The node seems to be new, sending auth url") + logTrace("The node seems to be new, sending auth url") if h.oauth2Config != nil { resp.AuthURL = fmt.Sprintf( @@ -441,10 +461,7 @@ func (h *Headscale) handleNewNode( respBody, err := json.Marshal(resp) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Cannot encode message") + logErr(err, "Cannot encode message") http.Error(writer, "Internal server error", http.StatusInternalServerError) return @@ -454,17 +471,10 @@ func (h *Headscale) handleNewNode( writer.WriteHeader(http.StatusOK) _, err = writer.Write(respBody) if err != nil { - log.Error(). - Caller(). - Err(err). - Msg("Failed to write response") + logErr(err, "Failed to write response") } - log.Info(). - Caller(). - Str("AuthURL", resp.AuthURL). - Str("node", registerRequest.Hostinfo.Hostname). - Msg("Successfully sent auth url") + logInfo(fmt.Sprintf("Successfully sent auth url: %s", resp.AuthURL)) } func (h *Headscale) handleNodeLogOut( @@ -490,6 +500,19 @@ func (h *Headscale) handleNodeLogOut( return } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: &now, + }, + }, + } + if stateUpdate.Valid() { + h.nodeNotifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) + } + resp.AuthURL = "" resp.MachineAuthorized = false resp.NodeKeyExpired = true diff --git a/hscontrol/db/db.go b/hscontrol/db/db.go index e4774480..55782764 100644 --- a/hscontrol/db/db.go +++ b/hscontrol/db/db.go @@ -171,11 +171,13 @@ func NewHeadscaleDatabase( dKey = "discokey:" + node.DiscoKey } - err := db.db.Exec("UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", + err := db.db.Exec( + "UPDATE nodes SET machine_key = @mKey, node_key = @nKey, disco_key = @dKey WHERE ID = @id", sql.Named("mKey", mKey), sql.Named("nKey", nKey), sql.Named("dKey", dKey), - sql.Named("id", node.ID)).Error + sql.Named("id", node.ID), + ).Error if err != nil { return nil, err } diff --git a/hscontrol/db/node.go b/hscontrol/db/node.go index bc122ee9..ac0e0b38 100644 --- a/hscontrol/db/node.go +++ b/hscontrol/db/node.go @@ -61,11 +61,6 @@ func (hsdb *HSDatabase) listPeers(node *types.Node) (types.Nodes, error) { sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) - log.Trace(). - Caller(). - Str("node", node.Hostname). - Msgf("Found peers: %s", nodes.String()) - return nodes, nil } @@ -176,6 +171,12 @@ func (hsdb *HSDatabase) GetNodeByMachineKey( hsdb.mu.RLock() defer hsdb.mu.RUnlock() + return hsdb.getNodeByMachineKey(machineKey) +} + +func (hsdb *HSDatabase) getNodeByMachineKey( + machineKey key.MachinePublic, +) (*types.Node, error) { mach := types.Node{} if result := hsdb.db. Preload("AuthKey"). 
@@ -252,6 +253,10 @@ func (hsdb *HSDatabase) SetTags( hsdb.mu.Lock() defer hsdb.mu.Unlock() + if len(tags) == 0 { + return nil + } + newTags := []string{} for _, tag := range tags { if !util.StringOrPrefixListContains(newTags, tag) { @@ -265,10 +270,14 @@ func (hsdb *HSDatabase) SetTags( return fmt.Errorf("failed to update tags for node in the database: %w", err) } - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: types.Nodes{node}, - }, node.MachineKey.String()) + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from db.SetTags", + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) + } return nil } @@ -301,10 +310,14 @@ func (hsdb *HSDatabase) RenameNode(node *types.Node, newName string) error { return fmt.Errorf("failed to rename node in the database: %w", err) } - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: types.Nodes{node}, - }, node.MachineKey.String()) + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from db.RenameNode", + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyWithIgnore(stateUpdate, node.MachineKey.String()) + } return nil } @@ -327,10 +340,18 @@ func (hsdb *HSDatabase) nodeSetExpiry(node *types.Node, expiry time.Time) error ) } - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: types.Nodes{node}, - }, node.MachineKey.String()) + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: &expiry, + }, + }, + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyAll(stateUpdate) + } return nil } @@ -354,10 +375,13 @@ func (hsdb *HSDatabase) deleteNode(node *types.Node) error { return err } - hsdb.notifier.NotifyAll(types.StateUpdate{ + stateUpdate := types.StateUpdate{ Type: types.StatePeerRemoved, Removed: []tailcfg.NodeID{tailcfg.NodeID(node.ID)}, - }) + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyAll(stateUpdate) + } return nil } @@ -629,20 +653,6 @@ func (hsdb *HSDatabase) IsRoutesEnabled(node *types.Node, routeStr string) bool return false } -func (hsdb *HSDatabase) ListOnlineNodes( - node *types.Node, -) (map[tailcfg.NodeID]bool, error) { - hsdb.mu.RLock() - defer hsdb.mu.RUnlock() - - peers, err := hsdb.listPeers(node) - if err != nil { - return nil, err - } - - return peers.OnlineNodeMap(), nil -} - // enableRoutes enables new routes based on a list of new routes. func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) error { newRoutes := make([]netip.Prefix, len(routeStrs)) @@ -694,10 +704,30 @@ func (hsdb *HSDatabase) enableRoutes(node *types.Node, routeStrs ...string) erro } } - hsdb.notifier.NotifyWithIgnore(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: types.Nodes{node}, - }, node.MachineKey.String()) + // Ensure the node has the latest routes when notifying the other + // nodes + nRoutes, err := hsdb.getNodeRoutes(node) + if err != nil { + return fmt.Errorf("failed to read back routes: %w", err) + } + + node.Routes = nRoutes + + log.Trace(). + Caller(). + Str("node", node.Hostname). + Strs("routes", routeStrs). 
+ Msg("enabling routes") + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from db.enableRoutes", + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyWithIgnore( + stateUpdate, node.MachineKey.String()) + } return nil } @@ -728,7 +758,10 @@ func generateGivenName(suppliedName string, randomSuffix bool) (string, error) { return normalizedHostname, nil } -func (hsdb *HSDatabase) GenerateGivenName(mkey key.MachinePublic, suppliedName string) (string, error) { +func (hsdb *HSDatabase) GenerateGivenName( + mkey key.MachinePublic, + suppliedName string, +) (string, error) { hsdb.mu.RLock() defer hsdb.mu.RUnlock() @@ -823,53 +856,54 @@ func (hsdb *HSDatabase) ExpireExpiredNodes(lastCheck time.Time) time.Time { // checked everything. started := time.Now() - users, err := hsdb.listUsers() + expired := make([]*tailcfg.PeerChange, 0) + + nodes, err := hsdb.listNodes() if err != nil { - log.Error().Err(err).Msg("Error listing users") + log.Error(). + Err(err). + Msg("Error listing nodes to find expired nodes") return time.Unix(0, 0) } + for index, node := range nodes { + if node.IsExpired() && + // TODO(kradalby): Replace this, it is very spammy + // It will notify about all nodes that has been expired. + // It should only notify about expired nodes since _last check_. + node.Expiry.After(lastCheck) { + expired = append(expired, &tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + KeyExpiry: node.Expiry, + }) - for _, user := range users { - nodes, err := hsdb.listNodesByUser(user.Name) - if err != nil { - log.Error(). - Err(err). - Str("user", user.Name). - Msg("Error listing nodes in user") - - return time.Unix(0, 0) - } - - expired := make([]tailcfg.NodeID, 0) - for index, node := range nodes { - if node.IsExpired() && - node.Expiry.After(lastCheck) { - expired = append(expired, tailcfg.NodeID(node.ID)) - - now := time.Now() - err := hsdb.nodeSetExpiry(nodes[index], now) - if err != nil { - log.Error(). - Err(err). - Str("node", node.Hostname). - Str("name", node.GivenName). - Msg("🤮 Cannot expire node") - } else { - log.Info(). - Str("node", node.Hostname). - Str("name", node.GivenName). - Msg("Node successfully expired") - } + now := time.Now() + // Do not use setNodeExpiry as that has a notifier hook, which + // can cause a deadlock, we are updating all changed nodes later + // and there is no point in notifiying twice. + if err := hsdb.db.Model(nodes[index]).Updates(types.Node{ + Expiry: &now, + }).Error; err != nil { + log.Error(). + Err(err). + Str("node", node.Hostname). + Str("name", node.GivenName). + Msg("🤮 Cannot expire node") + } else { + log.Info(). + Str("node", node.Hostname). + Str("name", node.GivenName). 
+ Msg("Node successfully expired") } } + } - if len(expired) > 0 { - hsdb.notifier.NotifyAll(types.StateUpdate{ - Type: types.StatePeerRemoved, - Removed: expired, - }) - } + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: expired, + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyAll(stateUpdate) } return started diff --git a/hscontrol/db/node_test.go b/hscontrol/db/node_test.go index be13f66d..140c264b 100644 --- a/hscontrol/db/node_test.go +++ b/hscontrol/db/node_test.go @@ -603,8 +603,9 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) { db.db.Save(&node) - err = db.SaveNodeRoutes(&node) + sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) + c.Assert(sendUpdate, check.Equals, false) node0ByID, err := db.GetNodeByID(0) c.Assert(err, check.IsNil) diff --git a/hscontrol/db/routes.go b/hscontrol/db/routes.go index 545bd2fa..51c7f3bc 100644 --- a/hscontrol/db/routes.go +++ b/hscontrol/db/routes.go @@ -7,7 +7,9 @@ import ( "github.com/juanfont/headscale/hscontrol/policy" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" + "github.com/samber/lo" "gorm.io/gorm" + "tailscale.com/types/key" ) var ErrRouteIsNotAvailable = errors.New("route is not available") @@ -21,7 +23,38 @@ func (hsdb *HSDatabase) GetRoutes() (types.Routes, error) { func (hsdb *HSDatabase) getRoutes() (types.Routes, error) { var routes types.Routes - err := hsdb.db.Preload("Node").Find(&routes).Error + err := hsdb.db. + Preload("Node"). + Preload("Node.User"). + Find(&routes).Error + if err != nil { + return nil, err + } + + return routes, nil +} + +func (hsdb *HSDatabase) getAdvertisedAndEnabledRoutes() (types.Routes, error) { + var routes types.Routes + err := hsdb.db. + Preload("Node"). + Preload("Node.User"). + Where("advertised = ? AND enabled = ?", true, true). + Find(&routes).Error + if err != nil { + return nil, err + } + + return routes, nil +} + +func (hsdb *HSDatabase) getRoutesByPrefix(pref netip.Prefix) (types.Routes, error) { + var routes types.Routes + err := hsdb.db. + Preload("Node"). + Preload("Node.User"). + Where("prefix = ?", types.IPPrefix(pref)). + Find(&routes).Error if err != nil { return nil, err } @@ -40,6 +73,7 @@ func (hsdb *HSDatabase) getNodeAdvertisedRoutes(node *types.Node) (types.Routes, var routes types.Routes err := hsdb.db. Preload("Node"). + Preload("Node.User"). Where("node_id = ? AND advertised = true", node.ID). Find(&routes).Error if err != nil { @@ -60,6 +94,7 @@ func (hsdb *HSDatabase) getNodeRoutes(node *types.Node) (types.Routes, error) { var routes types.Routes err := hsdb.db. Preload("Node"). + Preload("Node.User"). Where("node_id = ?", node.ID). Find(&routes).Error if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { @@ -78,7 +113,10 @@ func (hsdb *HSDatabase) GetRoute(id uint64) (*types.Route, error) { func (hsdb *HSDatabase) getRoute(id uint64) (*types.Route, error) { var route types.Route - err := hsdb.db.Preload("Node").First(&route, id).Error + err := hsdb.db. + Preload("Node"). + Preload("Node.User"). 
+ First(&route, id).Error if err != nil { return nil, err } @@ -122,37 +160,61 @@ func (hsdb *HSDatabase) DisableRoute(id uint64) error { return err } + var routes types.Routes + node := route.Node + // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if !route.IsExitRoute() { + err = hsdb.failoverRouteWithNotify(route) + if err != nil { + return err + } + route.Enabled = false route.IsPrimary = false err = hsdb.db.Save(route).Error if err != nil { return err } + } else { + routes, err = hsdb.getNodeRoutes(&node) + if err != nil { + return err + } - return hsdb.handlePrimarySubnetFailover() - } - - routes, err := hsdb.getNodeRoutes(&route.Node) - if err != nil { - return err - } - - for i := range routes { - if routes[i].IsExitRoute() { - routes[i].Enabled = false - routes[i].IsPrimary = false - err = hsdb.db.Save(&routes[i]).Error - if err != nil { - return err + for i := range routes { + if routes[i].IsExitRoute() { + routes[i].Enabled = false + routes[i].IsPrimary = false + err = hsdb.db.Save(&routes[i]).Error + if err != nil { + return err + } } } } - return hsdb.handlePrimarySubnetFailover() + if routes == nil { + routes, err = hsdb.getNodeRoutes(&node) + if err != nil { + return err + } + } + + node.Routes = routes + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{&node}, + Message: "called from db.DisableRoute", + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyAll(stateUpdate) + } + + return nil } func (hsdb *HSDatabase) DeleteRoute(id uint64) error { @@ -164,34 +226,58 @@ func (hsdb *HSDatabase) DeleteRoute(id uint64) error { return err } + var routes types.Routes + node := route.Node + // Tailscale requires both IPv4 and IPv6 exit routes to // be enabled at the same time, as per // https://github.com/juanfont/headscale/issues/804#issuecomment-1399314002 if !route.IsExitRoute() { + err := hsdb.failoverRouteWithNotify(route) + if err != nil { + return nil + } + if err := hsdb.db.Unscoped().Delete(&route).Error; err != nil { return err } + } else { + routes, err := hsdb.getNodeRoutes(&node) + if err != nil { + return err + } - return hsdb.handlePrimarySubnetFailover() - } + routesToDelete := types.Routes{} + for _, r := range routes { + if r.IsExitRoute() { + routesToDelete = append(routesToDelete, r) + } + } - routes, err := hsdb.getNodeRoutes(&route.Node) - if err != nil { - return err - } - - routesToDelete := types.Routes{} - for _, r := range routes { - if r.IsExitRoute() { - routesToDelete = append(routesToDelete, r) + if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { + return err } } - if err := hsdb.db.Unscoped().Delete(&routesToDelete).Error; err != nil { - return err + if routes == nil { + routes, err = hsdb.getNodeRoutes(&node) + if err != nil { + return err + } } - return hsdb.handlePrimarySubnetFailover() + node.Routes = routes + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{&node}, + Message: "called from db.DeleteRoute", + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyAll(stateUpdate) + } + + return nil } func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { @@ -204,9 +290,13 @@ func (hsdb *HSDatabase) deleteNodeRoutes(node *types.Node) error { if err := hsdb.db.Unscoped().Delete(&routes[i]).Error; err != nil { return err } + + // TODO(kradalby): This is a bit too aggressive, we could probably + // figure out 
which routes needs to be failed over rather than all. + hsdb.failoverRouteWithNotify(&routes[i]) } - return hsdb.handlePrimarySubnetFailover() + return nil } // isUniquePrefix returns if there is another node providing the same route already. @@ -259,18 +349,22 @@ func (hsdb *HSDatabase) GetNodePrimaryRoutes(node *types.Node) (types.Routes, er // SaveNodeRoutes takes a node and updates the database with // the new routes. -func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) error { +// It returns a bool wheter an update should be sent as the +// saved route impacts nodes. +func (hsdb *HSDatabase) SaveNodeRoutes(node *types.Node) (bool, error) { hsdb.mu.Lock() defer hsdb.mu.Unlock() return hsdb.saveNodeRoutes(node) } -func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { +func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) (bool, error) { + sendUpdate := false + currentRoutes := types.Routes{} err := hsdb.db.Where("node_id = ?", node.ID).Find(¤tRoutes).Error if err != nil { - return err + return sendUpdate, err } advertisedRoutes := map[netip.Prefix]bool{} @@ -290,7 +384,14 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { currentRoutes[pos].Advertised = true err := hsdb.db.Save(¤tRoutes[pos]).Error if err != nil { - return err + return sendUpdate, err + } + + // If a route that is newly "saved" is already + // enabled, set sendUpdate to true as it is now + // available. + if route.Enabled { + sendUpdate = true } } advertisedRoutes[netip.Prefix(route.Prefix)] = true @@ -299,7 +400,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { currentRoutes[pos].Enabled = false err := hsdb.db.Save(¤tRoutes[pos]).Error if err != nil { - return err + return sendUpdate, err } } } @@ -314,7 +415,41 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { } err := hsdb.db.Create(&route).Error if err != nil { - return err + return sendUpdate, err + } + } + } + + return sendUpdate, nil +} + +// EnsureFailoverRouteIsAvailable takes a node and checks if the node's route +// currently have a functioning host that exposes the network. +func (hsdb *HSDatabase) EnsureFailoverRouteIsAvailable(node *types.Node) error { + nodeRoutes, err := hsdb.getNodeRoutes(node) + if err != nil { + return nil + } + + for _, nodeRoute := range nodeRoutes { + routes, err := hsdb.getRoutesByPrefix(netip.Prefix(nodeRoute.Prefix)) + if err != nil { + return err + } + + for _, route := range routes { + if route.IsPrimary { + // if we have a primary route, and the node is connected + // nothing needs to be done. + if hsdb.notifier.IsConnected(route.Node.MachineKey) { + continue + } + + // if not, we need to failover the route + err := hsdb.failoverRouteWithNotify(&route) + if err != nil { + return err + } } } } @@ -322,133 +457,181 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error { return nil } -func (hsdb *HSDatabase) HandlePrimarySubnetFailover() error { - hsdb.mu.Lock() - defer hsdb.mu.Unlock() - - return hsdb.handlePrimarySubnetFailover() -} - -func (hsdb *HSDatabase) handlePrimarySubnetFailover() error { - // first, get all the enabled routes - var routes types.Routes - err := hsdb.db. - Preload("Node"). - Where("advertised = ? AND enabled = ?", true, true). 
- Find(&routes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error().Err(err).Msg("error getting routes") +func (hsdb *HSDatabase) FailoverNodeRoutesWithNotify(node *types.Node) error { + routes, err := hsdb.getNodeRoutes(node) + if err != nil { + return nil } - changedNodes := make(types.Nodes, 0) - for pos, route := range routes { - if route.IsExitRoute() { + var changedKeys []key.MachinePublic + + for _, route := range routes { + changed, err := hsdb.failoverRoute(&route) + if err != nil { + return err + } + + changedKeys = append(changedKeys, changed...) + } + + changedKeys = lo.Uniq(changedKeys) + + var nodes types.Nodes + + for _, key := range changedKeys { + node, err := hsdb.GetNodeByMachineKey(key) + if err != nil { + return err + } + + nodes = append(nodes, node) + } + + if nodes != nil { + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: nodes, + Message: "called from db.FailoverNodeRoutesWithNotify", + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyAll(stateUpdate) + } + } + + return nil +} + +func (hsdb *HSDatabase) failoverRouteWithNotify(r *types.Route) error { + changedKeys, err := hsdb.failoverRoute(r) + if err != nil { + return err + } + + if len(changedKeys) == 0 { + return nil + } + + var nodes types.Nodes + + log.Trace(). + Str("hostname", r.Node.Hostname). + Msg("loading machines with new primary routes from db") + + for _, key := range changedKeys { + node, err := hsdb.getNodeByMachineKey(key) + if err != nil { + return err + } + + nodes = append(nodes, node) + } + + log.Trace(). + Str("hostname", r.Node.Hostname). + Msg("notifying peers about primary route change") + + if nodes != nil { + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: nodes, + Message: "called from db.failoverRouteWithNotify", + } + if stateUpdate.Valid() { + hsdb.notifier.NotifyAll(stateUpdate) + } + } + + log.Trace(). + Str("hostname", r.Node.Hostname). + Msg("notified peers about primary route change") + + return nil +} + +// failoverRoute takes a route that is no longer available, +// this can be either from: +// - being disabled +// - being deleted +// - host going offline +// +// and tries to find a new route to take over its place. +// If the given route was not primary, it returns early. +func (hsdb *HSDatabase) failoverRoute(r *types.Route) ([]key.MachinePublic, error) { + if r == nil { + return nil, nil + } + + // This route is not a primary route, and it isnt + // being served to nodes. + if !r.IsPrimary { + return nil, nil + } + + // We do not have to failover exit nodes + if r.IsExitRoute() { + return nil, nil + } + + routes, err := hsdb.getRoutesByPrefix(netip.Prefix(r.Prefix)) + if err != nil { + return nil, err + } + + var newPrimary *types.Route + + // Find a new suitable route + for idx, route := range routes { + if r.ID == route.ID { continue } - node := &route.Node - - if !route.IsPrimary { - _, err := hsdb.getPrimaryRoute(netip.Prefix(route.Prefix)) - if hsdb.isUniquePrefix(route) || errors.Is(err, gorm.ErrRecordNotFound) { - log.Info(). - Str("prefix", netip.Prefix(route.Prefix).String()). - Str("node", route.Node.GivenName). 
- Msg("Setting primary route") - routes[pos].IsPrimary = true - err := hsdb.db.Save(&routes[pos]).Error - if err != nil { - log.Error().Err(err).Msg("error marking route as primary") - - return err - } - - changedNodes = append(changedNodes, node) - - continue - } - } - - if route.IsPrimary { - if route.Node.IsOnline() { - continue - } - - // node offline, find a new primary - log.Info(). - Str("node", route.Node.Hostname). - Str("prefix", netip.Prefix(route.Prefix).String()). - Msgf("node offline, finding a new primary subnet") - - // find a new primary route - var newPrimaryRoutes types.Routes - err := hsdb.db. - Preload("Node"). - Where("prefix = ? AND node_id != ? AND advertised = ? AND enabled = ?", - route.Prefix, - route.NodeID, - true, true). - Find(&newPrimaryRoutes).Error - if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { - log.Error().Err(err).Msg("error finding new primary route") - - return err - } - - var newPrimaryRoute *types.Route - for pos, r := range newPrimaryRoutes { - if r.Node.IsOnline() { - newPrimaryRoute = &newPrimaryRoutes[pos] - - break - } - } - - if newPrimaryRoute == nil { - log.Warn(). - Str("node", route.Node.Hostname). - Str("prefix", netip.Prefix(route.Prefix).String()). - Msgf("no alternative primary route found") - - continue - } - - log.Info(). - Str("old_node", route.Node.Hostname). - Str("prefix", netip.Prefix(route.Prefix).String()). - Str("new_node", newPrimaryRoute.Node.Hostname). - Msgf("found new primary route") - - // disable the old primary route - routes[pos].IsPrimary = false - err = hsdb.db.Save(&routes[pos]).Error - if err != nil { - log.Error().Err(err).Msg("error disabling old primary route") - - return err - } - - // enable the new primary route - newPrimaryRoute.IsPrimary = true - err = hsdb.db.Save(&newPrimaryRoute).Error - if err != nil { - log.Error().Err(err).Msg("error enabling new primary route") - - return err - } - - changedNodes = append(changedNodes, node) + if hsdb.notifier.IsConnected(route.Node.MachineKey) { + newPrimary = &routes[idx] + break } } - if len(changedNodes) > 0 { - hsdb.notifier.NotifyAll(types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: changedNodes, - }) + // If a new route was not found/available, + // return with an error. + // We do not want to update the database as + // the one currently marked as primary is the + // best we got. + if newPrimary == nil { + return nil, nil } - return nil + log.Trace(). + Str("hostname", newPrimary.Node.Hostname). + Msg("found new primary, updating db") + + // Remove primary from the old route + r.IsPrimary = false + err = hsdb.db.Save(&r).Error + if err != nil { + log.Error().Err(err).Msg("error disabling new primary route") + + return nil, err + } + + log.Trace(). + Str("hostname", newPrimary.Node.Hostname). + Msg("removed primary from old route") + + // Set primary for the new primary + newPrimary.IsPrimary = true + err = hsdb.db.Save(&newPrimary).Error + if err != nil { + log.Error().Err(err).Msg("error enabling new primary route") + + return nil, err + } + + log.Trace(). + Str("hostname", newPrimary.Node.Hostname). + Msg("set primary to new route") + + // Return a list of the machinekeys of the changed nodes. + return []key.MachinePublic{r.Node.MachineKey, newPrimary.Node.MachineKey}, nil } // EnableAutoApprovedRoutes enables any routes advertised by a node that match the ACL autoApprovers policy. 
diff --git a/hscontrol/db/routes_test.go b/hscontrol/db/routes_test.go index 92730afa..d491b6a3 100644 --- a/hscontrol/db/routes_test.go +++ b/hscontrol/db/routes_test.go @@ -2,12 +2,19 @@ package db import ( "net/netip" + "os" + "testing" "time" + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/notifier" "github.com/juanfont/headscale/hscontrol/types" "github.com/juanfont/headscale/hscontrol/util" + "github.com/stretchr/testify/assert" "gopkg.in/check.v1" + "gorm.io/gorm" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) func (s *Suite) TestGetRoutes(c *check.C) { @@ -37,8 +44,9 @@ func (s *Suite) TestGetRoutes(c *check.C) { } db.db.Save(&node) - err = db.SaveNodeRoutes(&node) + su, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) + c.Assert(su, check.Equals, false) advertisedRoutes, err := db.GetAdvertisedRoutes(&node) c.Assert(err, check.IsNil) @@ -85,8 +93,9 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) { } db.db.Save(&node) - err = db.SaveNodeRoutes(&node) + sendUpdate, err := db.SaveNodeRoutes(&node) c.Assert(err, check.IsNil) + c.Assert(sendUpdate, check.Equals, false) availableRoutes, err := db.GetAdvertisedRoutes(&node) c.Assert(err, check.IsNil) @@ -156,8 +165,9 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { } db.db.Save(&node1) - err = db.SaveNodeRoutes(&node1) + sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) + c.Assert(sendUpdate, check.Equals, false) err = db.enableRoutes(&node1, route.String()) c.Assert(err, check.IsNil) @@ -178,8 +188,9 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { } db.db.Save(&node2) - err = db.SaveNodeRoutes(&node2) + sendUpdate, err = db.SaveNodeRoutes(&node2) c.Assert(err, check.IsNil) + c.Assert(sendUpdate, check.Equals, false) err = db.enableRoutes(&node2, route2.String()) c.Assert(err, check.IsNil) @@ -201,142 +212,6 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) { c.Assert(len(routes), check.Equals, 0) } -func (s *Suite) TestSubnetFailover(c *check.C) { - user, err := db.CreateUser("test") - c.Assert(err, check.IsNil) - - pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil) - c.Assert(err, check.IsNil) - - _, err = db.GetNode("test", "test_enable_route_node") - c.Assert(err, check.NotNil) - - prefix, err := netip.ParsePrefix( - "10.0.0.0/24", - ) - c.Assert(err, check.IsNil) - - prefix2, err := netip.ParsePrefix( - "150.0.10.0/25", - ) - c.Assert(err, check.IsNil) - - hostInfo1 := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{prefix, prefix2}, - } - - now := time.Now() - node1 := types.Node{ - ID: 1, - Hostname: "test_enable_route_node", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - AuthKeyID: uint(pak.ID), - Hostinfo: &hostInfo1, - LastSeen: &now, - } - db.db.Save(&node1) - - err = db.SaveNodeRoutes(&node1) - c.Assert(err, check.IsNil) - - err = db.enableRoutes(&node1, prefix.String()) - c.Assert(err, check.IsNil) - - err = db.enableRoutes(&node1, prefix2.String()) - c.Assert(err, check.IsNil) - - err = db.HandlePrimarySubnetFailover() - c.Assert(err, check.IsNil) - - enabledRoutes1, err := db.GetEnabledRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes1), check.Equals, 2) - - route, err := db.getPrimaryRoute(prefix) - c.Assert(err, check.IsNil) - c.Assert(route.NodeID, check.Equals, node1.ID) - - hostInfo2 := tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{prefix2}, - } - node2 := types.Node{ - ID: 2, - Hostname: "test_enable_route_node", - UserID: user.ID, - RegisterMethod: util.RegisterMethodAuthKey, - 
AuthKeyID: uint(pak.ID), - Hostinfo: &hostInfo2, - LastSeen: &now, - } - db.db.Save(&node2) - - err = db.saveNodeRoutes(&node2) - c.Assert(err, check.IsNil) - - err = db.enableRoutes(&node2, prefix2.String()) - c.Assert(err, check.IsNil) - - err = db.HandlePrimarySubnetFailover() - c.Assert(err, check.IsNil) - - enabledRoutes1, err = db.GetEnabledRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes1), check.Equals, 2) - - enabledRoutes2, err := db.GetEnabledRoutes(&node2) - c.Assert(err, check.IsNil) - c.Assert(len(enabledRoutes2), check.Equals, 1) - - routes, err := db.GetNodePrimaryRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 2) - - routes, err = db.GetNodePrimaryRoutes(&node2) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 0) - - // lets make node1 lastseen 10 mins ago - before := now.Add(-10 * time.Minute) - node1.LastSeen = &before - err = db.db.Save(&node1).Error - c.Assert(err, check.IsNil) - - err = db.HandlePrimarySubnetFailover() - c.Assert(err, check.IsNil) - - routes, err = db.GetNodePrimaryRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 1) - - routes, err = db.GetNodePrimaryRoutes(&node2) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 1) - - node2.Hostinfo = &tailcfg.Hostinfo{ - RoutableIPs: []netip.Prefix{prefix, prefix2}, - } - err = db.db.Save(&node2).Error - c.Assert(err, check.IsNil) - - err = db.SaveNodeRoutes(&node2) - c.Assert(err, check.IsNil) - - err = db.enableRoutes(&node2, prefix.String()) - c.Assert(err, check.IsNil) - - err = db.HandlePrimarySubnetFailover() - c.Assert(err, check.IsNil) - - routes, err = db.GetNodePrimaryRoutes(&node1) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 0) - - routes, err = db.GetNodePrimaryRoutes(&node2) - c.Assert(err, check.IsNil) - c.Assert(len(routes), check.Equals, 2) -} - func (s *Suite) TestDeleteRoutes(c *check.C) { user, err := db.CreateUser("test") c.Assert(err, check.IsNil) @@ -373,8 +248,9 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { } db.db.Save(&node1) - err = db.SaveNodeRoutes(&node1) + sendUpdate, err := db.SaveNodeRoutes(&node1) c.Assert(err, check.IsNil) + c.Assert(sendUpdate, check.Equals, false) err = db.enableRoutes(&node1, prefix.String()) c.Assert(err, check.IsNil) @@ -392,3 +268,362 @@ func (s *Suite) TestDeleteRoutes(c *check.C) { c.Assert(err, check.IsNil) c.Assert(len(enabledRoutes1), check.Equals, 1) } + +func TestFailoverRoute(t *testing.T) { + ipp := func(s string) types.IPPrefix { return types.IPPrefix(netip.MustParsePrefix(s)) } + + // TODO(kradalby): Count/verify updates + var sink chan types.StateUpdate + + go func() { + for range sink { + } + }() + + machineKeys := []key.MachinePublic{ + key.NewMachine().Public(), + key.NewMachine().Public(), + key.NewMachine().Public(), + key.NewMachine().Public(), + } + + tests := []struct { + name string + failingRoute types.Route + routes types.Routes + want []key.MachinePublic + wantErr bool + }{ + { + name: "no-route", + failingRoute: types.Route{}, + routes: types.Routes{}, + want: nil, + wantErr: false, + }, + { + name: "no-prime", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: false, + }, + routes: types.Routes{}, + want: nil, + wantErr: false, + }, + { + name: "exit-node", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("0.0.0.0/0"), + Node: types.Node{ + 
MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + routes: types.Routes{}, + want: nil, + wantErr: false, + }, + { + name: "no-failover-single-route", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + }, + want: nil, + wantErr: false, + }, + { + name: "failover-primary", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[1], + }, + IsPrimary: false, + }, + }, + want: []key.MachinePublic{ + machineKeys[0], + machineKeys[1], + }, + wantErr: false, + }, + { + name: "failover-none-primary", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: false, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[1], + }, + IsPrimary: false, + }, + }, + want: nil, + wantErr: false, + }, + { + name: "failover-primary-multi-route", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[1], + }, + IsPrimary: true, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: false, + }, + types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[1], + }, + IsPrimary: true, + }, + types.Route{ + Model: gorm.Model{ + ID: 3, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[2], + }, + IsPrimary: false, + }, + }, + want: []key.MachinePublic{ + machineKeys[1], + machineKeys[0], + }, + wantErr: false, + }, + { + name: "failover-primary-no-online", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + // Offline + types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[3], + }, + IsPrimary: false, + }, + }, + want: nil, + wantErr: false, + }, + { + name: "failover-primary-one-not-online", + failingRoute: types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + routes: types.Routes{ + types.Route{ + Model: gorm.Model{ + ID: 1, + }, + Prefix: ipp("10.0.0.0/24"), + 
Node: types.Node{ + MachineKey: machineKeys[0], + }, + IsPrimary: true, + }, + // Offline + types.Route{ + Model: gorm.Model{ + ID: 2, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[3], + }, + IsPrimary: false, + }, + types.Route{ + Model: gorm.Model{ + ID: 3, + }, + Prefix: ipp("10.0.0.0/24"), + Node: types.Node{ + MachineKey: machineKeys[1], + }, + IsPrimary: true, + }, + }, + want: []key.MachinePublic{ + machineKeys[0], + machineKeys[1], + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "failover-db-test") + assert.NoError(t, err) + + notif := notifier.NewNotifier() + + db, err = NewHeadscaleDatabase( + "sqlite3", + tmpDir+"/headscale_test.db", + false, + notif, + []netip.Prefix{ + netip.MustParsePrefix("10.27.0.0/23"), + }, + "", + ) + assert.NoError(t, err) + + // Pretend that all the nodes are connected to control + for idx, key := range machineKeys { + // Pretend one node is offline + if idx == 3 { + continue + } + + notif.AddNode(key, sink) + } + + for _, route := range tt.routes { + if err := db.db.Save(&route).Error; err != nil { + t.Fatalf("failed to create route: %s", err) + } + } + + got, err := db.failoverRoute(&tt.failingRoute) + + if (err != nil) != tt.wantErr { + t.Errorf("failoverRoute() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { + t.Errorf("failoverRoute() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/derp/server/derp_server.go b/hscontrol/derp/server/derp_server.go index d59966b6..c92595d0 100644 --- a/hscontrol/derp/server/derp_server.go +++ b/hscontrol/derp/server/derp_server.go @@ -18,6 +18,7 @@ import ( "tailscale.com/net/stun" "tailscale.com/tailcfg" "tailscale.com/types/key" + "tailscale.com/types/logger" ) // fastStartHeader is the header (with value "1") that signals to the HTTP @@ -33,13 +34,19 @@ type DERPServer struct { tailscaleDERP *derp.Server } +func derpLogf() logger.Logf { + return func(format string, args ...any) { + log.Debug().Caller().Msgf(format, args...) + } +} + func NewDERPServer( serverURL string, derpKey key.NodePrivate, cfg *types.DERPConfig, ) (*DERPServer, error) { log.Trace().Caller().Msg("Creating new embedded DERP server") - server := derp.NewServer(derpKey, log.Debug().Msgf) // nolint // zerolinter complains + server := derp.NewServer(derpKey, derpLogf()) // nolint // zerolinter complains return &DERPServer{ serverURL: serverURL, diff --git a/hscontrol/grpcv1.go b/hscontrol/grpcv1.go index 9139513e..ffd3a576 100644 --- a/hscontrol/grpcv1.go +++ b/hscontrol/grpcv1.go @@ -204,7 +204,13 @@ func (api headscaleV1APIServer) GetNode( return nil, err } - return &v1.GetNodeResponse{Node: node.Proto()}, nil + resp := node.Proto() + + // Populate the online field based on + // currently connected nodes. + resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + + return &v1.GetNodeResponse{Node: resp}, nil } func (api headscaleV1APIServer) SetTags( @@ -333,7 +339,13 @@ func (api headscaleV1APIServer) ListNodes( response := make([]*v1.Node, len(nodes)) for index, node := range nodes { - response[index] = node.Proto() + resp := node.Proto() + + // Populate the online field based on + // currently connected nodes. 
+ resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + + response[index] = resp } return &v1.ListNodesResponse{Nodes: response}, nil @@ -346,13 +358,18 @@ func (api headscaleV1APIServer) ListNodes( response := make([]*v1.Node, len(nodes)) for index, node := range nodes { - m := node.Proto() + resp := node.Proto() + + // Populate the online field based on + // currently connected nodes. + resp.Online = api.h.nodeNotifier.IsConnected(node.MachineKey) + validTags, invalidTags := api.h.ACLPolicy.TagsOfNode( &node, ) - m.InvalidTags = invalidTags - m.ValidTags = validTags - response[index] = m + resp.InvalidTags = invalidTags + resp.ValidTags = validTags + response[index] = resp } return &v1.ListNodesResponse{Nodes: response}, nil diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index 6aa1294d..0a848b8d 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -8,6 +8,7 @@ import ( "net/url" "os" "path" + "slices" "sort" "strings" "sync" @@ -21,6 +22,7 @@ import ( "github.com/klauspost/compress/zstd" "github.com/rs/zerolog/log" "github.com/samber/lo" + "golang.org/x/exp/maps" "tailscale.com/envknob" "tailscale.com/smallzstd" "tailscale.com/tailcfg" @@ -45,6 +47,7 @@ var debugDumpMapResponsePath = envknob.String("HEADSCALE_DEBUG_DUMP_MAPRESPONSE_ // - Keep information about the previous mapresponse so we can send a diff // - Store hashes // - Create a "minifier" that removes info not needed for the node +// - some sort of batching, wait for 5 or 60 seconds before sending type Mapper struct { // Configuration @@ -61,8 +64,14 @@ type Mapper struct { // Map isnt concurrency safe, so we need to ensure // only one func is accessing it over time. - mu sync.Mutex - peers map[uint64]*types.Node + mu sync.Mutex + peers map[uint64]*types.Node + patches map[uint64][]patch +} + +type patch struct { + timestamp time.Time + change *tailcfg.PeerChange } func NewMapper( @@ -93,7 +102,8 @@ func NewMapper( seq: 0, // TODO: populate - peers: peers.IDMap(), + peers: peers.IDMap(), + patches: make(map[uint64][]patch), } } @@ -235,6 +245,19 @@ func (m *Mapper) FullMapResponse( m.mu.Lock() defer m.mu.Unlock() + peers := maps.Keys(m.peers) + peersWithPatches := maps.Keys(m.patches) + slices.Sort(peers) + slices.Sort(peersWithPatches) + + if len(peersWithPatches) > 0 { + log.Debug(). + Str("node", node.Hostname). + Uints64("peers", peers). + Uints64("pending_patches", peersWithPatches). + Msgf("node requested full map response, but has pending patches") + } + resp, err := m.fullMapResponse(node, pol, mapRequest.Version) if err != nil { return nil, err @@ -272,10 +295,12 @@ func (m *Mapper) KeepAliveResponse( func (m *Mapper) DERPMapResponse( mapRequest tailcfg.MapRequest, node *types.Node, - derpMap tailcfg.DERPMap, + derpMap *tailcfg.DERPMap, ) ([]byte, error) { + m.derpMap = derpMap + resp := m.baseMapResponse() - resp.DERPMap = &derpMap + resp.DERPMap = derpMap return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) } @@ -285,18 +310,29 @@ func (m *Mapper) PeerChangedResponse( node *types.Node, changed types.Nodes, pol *policy.ACLPolicy, + messages ...string, ) ([]byte, error) { m.mu.Lock() defer m.mu.Unlock() - lastSeen := make(map[tailcfg.NodeID]bool) - // Update our internal map. for _, node := range changed { - m.peers[node.ID] = node + if patches, ok := m.patches[node.ID]; ok { + // preserve online status in case the patch has an outdated one + online := node.IsOnline - // We have just seen the node, let the peers update their list. 
- lastSeen[tailcfg.NodeID(node.ID)] = true + for _, p := range patches { + // TODO(kradalby): Figure if this needs to be sorted by timestamp + node.ApplyPeerChange(p.change) + } + + // Ensure the patches are not applied again later + delete(m.patches, node.ID) + + node.IsOnline = online + } + + m.peers[node.ID] = node } resp := m.baseMapResponse() @@ -316,11 +352,55 @@ func (m *Mapper) PeerChangedResponse( return nil, err } - // resp.PeerSeenChange = lastSeen + return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress, messages...) +} + +// PeerChangedPatchResponse creates a patch MapResponse with +// incoming update from a state change. +func (m *Mapper) PeerChangedPatchResponse( + mapRequest tailcfg.MapRequest, + node *types.Node, + changed []*tailcfg.PeerChange, + pol *policy.ACLPolicy, +) ([]byte, error) { + m.mu.Lock() + defer m.mu.Unlock() + + sendUpdate := false + // patch the internal map + for _, change := range changed { + if peer, ok := m.peers[uint64(change.NodeID)]; ok { + peer.ApplyPeerChange(change) + sendUpdate = true + } else { + log.Trace().Str("node", node.Hostname).Msgf("Node with ID %s is missing from mapper for Node %s, saving patch for when node is available", change.NodeID, node.Hostname) + + p := patch{ + timestamp: time.Now(), + change: change, + } + + if patches, ok := m.patches[uint64(change.NodeID)]; ok { + patches := append(patches, p) + + m.patches[uint64(change.NodeID)] = patches + } else { + m.patches[uint64(change.NodeID)] = []patch{p} + } + } + } + + if !sendUpdate { + return nil, nil + } + + resp := m.baseMapResponse() + resp.PeersChangedPatch = changed return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) } +// TODO(kradalby): We need some integration tests for this. func (m *Mapper) PeerRemovedResponse( mapRequest tailcfg.MapRequest, node *types.Node, @@ -329,13 +409,23 @@ func (m *Mapper) PeerRemovedResponse( m.mu.Lock() defer m.mu.Unlock() + // Some nodes might have been removed already + // so we dont want to ask downstream to remove + // twice, than can cause a panic in tailscaled. + notYetRemoved := []tailcfg.NodeID{} + // remove from our internal map for _, id := range removed { + if _, ok := m.peers[uint64(id)]; ok { + notYetRemoved = append(notYetRemoved, id) + } + delete(m.peers, uint64(id)) + delete(m.patches, uint64(id)) } resp := m.baseMapResponse() - resp.PeersRemoved = removed + resp.PeersRemoved = notYetRemoved return m.marshalMapResponse(mapRequest, &resp, node, mapRequest.Compress) } @@ -345,6 +435,7 @@ func (m *Mapper) marshalMapResponse( resp *tailcfg.MapResponse, node *types.Node, compression string, + messages ...string, ) ([]byte, error) { atomic.AddUint64(&m.seq, 1) @@ -358,11 +449,25 @@ func (m *Mapper) marshalMapResponse( if debugDumpMapResponsePath != "" { data := map[string]interface{}{ + "Messages": messages, "MapRequest": mapRequest, "MapResponse": resp, } - body, err := json.Marshal(data) + responseType := "keepalive" + + switch { + case resp.Peers != nil && len(resp.Peers) > 0: + responseType = "full" + case resp.PeersChanged != nil && len(resp.PeersChanged) > 0: + responseType = "changed" + case resp.PeersChangedPatch != nil && len(resp.PeersChangedPatch) > 0: + responseType = "patch" + case resp.PeersRemoved != nil && len(resp.PeersRemoved) > 0: + responseType = "removed" + } + + body, err := json.MarshalIndent(data, "", " ") if err != nil { log.Error(). Caller(). 
@@ -381,7 +486,7 @@ func (m *Mapper) marshalMapResponse( mapResponsePath := path.Join( mPath, - fmt.Sprintf("%d-%s-%d.json", now, m.uid, atomic.LoadUint64(&m.seq)), + fmt.Sprintf("%d-%s-%d-%s.json", now, m.uid, atomic.LoadUint64(&m.seq), responseType), ) log.Trace().Msgf("Writing MapResponse to %s", mapResponsePath) @@ -438,6 +543,7 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { resp := tailcfg.MapResponse{ KeepAlive: false, ControlTime: &now, + // TODO(kradalby): Implement PingRequest? } return resp @@ -559,8 +665,5 @@ func appendPeerChanges( resp.UserProfiles = profiles resp.SSHPolicy = sshPolicy - // TODO(kradalby): This currently does not take last seen in keepalives into account - resp.OnlineChange = peers.OnlineNodeMap() - return nil } diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index a5a5dceb..bcc17dd4 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -237,7 +237,6 @@ func Test_fullMapResponse(t *testing.T) { Tags: []string{}, PrimaryRoutes: []netip.Prefix{netip.MustParsePrefix("192.168.0.0/24")}, LastSeen: &lastSeen, - Online: new(bool), MachineAuthorized: true, Capabilities: []tailcfg.NodeCapability{ tailcfg.CapabilityFileSharing, @@ -293,7 +292,6 @@ func Test_fullMapResponse(t *testing.T) { Tags: []string{}, PrimaryRoutes: []netip.Prefix{}, LastSeen: &lastSeen, - Online: new(bool), MachineAuthorized: true, Capabilities: []tailcfg.NodeCapability{ tailcfg.CapabilityFileSharing, @@ -400,7 +398,6 @@ func Test_fullMapResponse(t *testing.T) { DNSConfig: &tailcfg.DNSConfig{}, Domain: "", CollectServices: "false", - OnlineChange: map[tailcfg.NodeID]bool{tailPeer1.ID: false}, PacketFilter: []tailcfg.FilterRule{}, UserProfiles: []tailcfg.UserProfile{{LoginName: "mini", DisplayName: "mini"}}, SSHPolicy: &tailcfg.SSHPolicy{Rules: []*tailcfg.SSHRule{}}, @@ -442,10 +439,6 @@ func Test_fullMapResponse(t *testing.T) { DNSConfig: &tailcfg.DNSConfig{}, Domain: "", CollectServices: "false", - OnlineChange: map[tailcfg.NodeID]bool{ - tailPeer1.ID: false, - tailcfg.NodeID(peer2.ID): false, - }, PacketFilter: []tailcfg.FilterRule{ { SrcIPs: []string{"100.64.0.2/32"}, diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index a4367720..e213a951 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -87,11 +87,9 @@ func tailNode( hostname, err := node.GetFQDN(dnsConfig, baseDomain) if err != nil { - return nil, err + return nil, fmt.Errorf("tailNode, failed to create FQDN: %s", err) } - online := node.IsOnline() - tags, _ := pol.TagsOfNode(node) tags = lo.Uniq(append(tags, node.ForcedTags...)) @@ -101,6 +99,7 @@ func tailNode( strconv.FormatUint(node.ID, util.Base10), ), // in headscale, unlike tailcontrol server, IDs are permanent Name: hostname, + Cap: capVer, User: tailcfg.UserID(node.UserID), @@ -116,13 +115,14 @@ func tailNode( Hostinfo: node.Hostinfo.View(), Created: node.CreatedAt, + Online: node.IsOnline, + Tags: tags, PrimaryRoutes: primaryPrefixes, - LastSeen: node.LastSeen, - Online: &online, MachineAuthorized: !node.IsExpired(), + Expired: node.IsExpired(), } // - 74: 2023-09-18: Client understands NodeCapMap @@ -153,5 +153,11 @@ func tailNode( tNode.Capabilities = append(tNode.Capabilities, tailcfg.NodeAttrDisableUPnP) } + if node.IsOnline == nil || !*node.IsOnline { + // LastSeen is only set when node is + // not connected to the control server. 
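+		// Mirrors the tailcfg convention that LastSeen is only present for
+		// peers that are offline; while a node has an active session the
+		// Online field above carries the state and LastSeen stays unset.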
+ tNode.LastSeen = node.LastSeen + } + return &tNode, nil } diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 936f2756..f6e370c4 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -68,7 +68,6 @@ func TestTailNode(t *testing.T) { Hostinfo: hiview(tailcfg.Hostinfo{}), Tags: []string{}, PrimaryRoutes: []netip.Prefix{}, - Online: new(bool), MachineAuthorized: true, Capabilities: []tailcfg.NodeCapability{ "https://tailscale.com/cap/file-sharing", "https://tailscale.com/cap/is-admin", @@ -165,7 +164,6 @@ func TestTailNode(t *testing.T) { }, LastSeen: &lastSeen, - Online: new(bool), MachineAuthorized: true, Capabilities: []tailcfg.NodeCapability{ diff --git a/hscontrol/notifier/notifier.go b/hscontrol/notifier/notifier.go index fee0befb..ae0aad46 100644 --- a/hscontrol/notifier/notifier.go +++ b/hscontrol/notifier/notifier.go @@ -1,6 +1,8 @@ package notifier import ( + "fmt" + "strings" "sync" "github.com/juanfont/headscale/hscontrol/types" @@ -56,6 +58,19 @@ func (n *Notifier) RemoveNode(machineKey key.MachinePublic) { Msg("Removed channel") } +// IsConnected reports if a node is connected to headscale and has a +// poll session open. +func (n *Notifier) IsConnected(machineKey key.MachinePublic) bool { + n.l.RLock() + defer n.l.RUnlock() + + if _, ok := n.nodes[machineKey.String()]; ok { + return true + } + + return false +} + func (n *Notifier) NotifyAll(update types.StateUpdate) { n.NotifyWithIgnore(update) } @@ -79,3 +94,16 @@ func (n *Notifier) NotifyWithIgnore(update types.StateUpdate, ignore ...string) c <- update } } + +func (n *Notifier) String() string { + n.l.RLock() + defer n.l.RUnlock() + + str := []string{"Notifier, in map:\n"} + + for k, v := range n.nodes { + str = append(str, fmt.Sprintf("\t%s: %v\n", k, v)) + } + + return strings.Join(str, "") +} diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index aca1b499..c048778d 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -14,29 +14,8 @@ import ( "go4.org/netipx" "gopkg.in/check.v1" "tailscale.com/tailcfg" - "tailscale.com/types/key" ) -var ipComparer = cmp.Comparer(func(x, y netip.Addr) bool { - return x.Compare(y) == 0 -}) - -var mkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool { - return x.String() == y.String() -}) - -var nkeyComparer = cmp.Comparer(func(x, y key.NodePublic) bool { - return x.String() == y.String() -}) - -var dkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool { - return x.String() == y.String() -}) - -var keyComparers []cmp.Option = []cmp.Option{ - mkeyComparer, nkeyComparer, dkeyComparer, -} - func Test(t *testing.T) { check.TestingT(t) } @@ -969,7 +948,7 @@ func Test_listNodesInUser(t *testing.T) { t.Run(test.name, func(t *testing.T) { got := filterNodesByUser(test.args.nodes, test.args.user) - if diff := cmp.Diff(test.want, got, keyComparers...); diff != "" { + if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { t.Errorf("listNodesInUser() = (-want +got):\n%s", diff) } }) @@ -1733,7 +1712,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) { test.args.nodes, test.args.user, ) - if diff := cmp.Diff(test.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" { + if diff := cmp.Diff(test.want, got, util.Comparers...); diff != "" { t.Errorf("excludeCorrectlyTaggedNodes() (-want +got):\n%s", diff) } }) @@ -2085,10 +2064,6 @@ func Test_getTags(t *testing.T) { } func Test_getFilteredByACLPeers(t *testing.T) { - ipComparer := 
cmp.Comparer(func(x, y netip.Addr) bool { - return x.Compare(y) == 0 - }) - type args struct { nodes types.Nodes rules []tailcfg.FilterRule @@ -2752,7 +2727,7 @@ func Test_getFilteredByACLPeers(t *testing.T) { tt.args.nodes, tt.args.rules, ) - if diff := cmp.Diff(tt.want, got, ipComparer, mkeyComparer, nkeyComparer, dkeyComparer); diff != "" { + if diff := cmp.Diff(tt.want, got, util.Comparers...); diff != "" { t.Errorf("FilterNodesByACL() unexpected result (-want +got):\n%s", diff) } }) diff --git a/hscontrol/poll.go b/hscontrol/poll.go index 31801952..a07fda08 100644 --- a/hscontrol/poll.go +++ b/hscontrol/poll.go @@ -9,6 +9,7 @@ import ( "github.com/juanfont/headscale/hscontrol/mapper" "github.com/juanfont/headscale/hscontrol/types" "github.com/rs/zerolog/log" + xslices "golang.org/x/exp/slices" "tailscale.com/tailcfg" ) @@ -61,7 +62,7 @@ func (h *Headscale) handlePoll( ) { logInfo, logErr := logPollFunc(mapRequest, node) - // This is the mechanism where the node gives us inforamtion about its + // This is the mechanism where the node gives us information about its // current configuration. // // If OmitPeers is true, Stream is false, and ReadOnly is false, @@ -69,6 +70,7 @@ func (h *Headscale) handlePoll( // breaking existing long-polling (Stream == true) connections. // In this case, the server can omit the entire response; the client // only checks the HTTP response status code. + // TODO(kradalby): remove ReadOnly when we only support capVer 68+ if mapRequest.OmitPeers && !mapRequest.Stream && !mapRequest.ReadOnly { log.Info(). Caller(). @@ -78,14 +80,85 @@ func (h *Headscale) handlePoll( Str("node_key", node.NodeKey.ShortString()). Str("node", node.Hostname). Int("cap_ver", int(mapRequest.Version)). - Msg("Received endpoint update") + Msg("Received update") - now := time.Now().UTC() - node.LastSeen = &now - node.Hostname = mapRequest.Hostinfo.Hostname - node.Hostinfo = mapRequest.Hostinfo - node.DiscoKey = mapRequest.DiscoKey - node.Endpoints = mapRequest.Endpoints + change := node.PeerChangeFromMapRequest(mapRequest) + + online := h.nodeNotifier.IsConnected(node.MachineKey) + change.Online = &online + + node.ApplyPeerChange(&change) + + hostInfoChange := node.Hostinfo.Equal(mapRequest.Hostinfo) + + logTracePeerChange(node.Hostname, hostInfoChange, &change) + + // Check if the Hostinfo of the node has changed. + // If it has changed, check if there has been a change tod + // the routable IPs of the host and update update them in + // the database. Then send a Changed update + // (containing the whole node object) to peers to inform about + // the route change. + // If the hostinfo has changed, but not the routes, just update + // hostinfo and let the function continue. + if !hostInfoChange { + oldRoutes := node.Hostinfo.RoutableIPs + newRoutes := mapRequest.Hostinfo.RoutableIPs + + oldServicesCount := len(node.Hostinfo.Services) + newServicesCount := len(mapRequest.Hostinfo.Services) + + node.Hostinfo = mapRequest.Hostinfo + + sendUpdate := false + + // Route changes come as part of Hostinfo, which means that + // when an update comes, the Node Route logic need to run. + // This will require a "change" in comparison to a "patch", + // which is more costly. 
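+			// In this file a "change" is a StatePeerChanged update carrying
+			// full node objects, which every receiving stream re-runs through
+			// the ACL policy, while a "patch" is a StatePeerChangedPatch
+			// carrying only a tailcfg.PeerChange delta (endpoints, DERP
+			// region, online state) that peers can apply directly. Roughly
+			// (illustrative, mirroring the updates built further down):
+			//
+			//	patch := types.StateUpdate{Type: types.StatePeerChangedPatch, ChangePatches: []*tailcfg.PeerChange{&change}}
+			//	full := types.StateUpdate{Type: types.StatePeerChanged, ChangeNodes: types.Nodes{node}, Message: "why the update was sent"}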
+ if !xslices.Equal(oldRoutes, newRoutes) { + var err error + sendUpdate, err = h.db.SaveNodeRoutes(node) + if err != nil { + logErr(err, "Error processing node routes") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + } + + // Services is mostly useful for discovery and not critical, + // except for peerapi, which is how nodes talk to eachother. + // If peerapi was not part of the initial mapresponse, we + // need to make sure its sent out later as it is needed for + // Taildrop. + // TODO(kradalby): Length comparison is a bit naive, replace. + if oldServicesCount != newServicesCount { + sendUpdate = true + } + + if sendUpdate { + if err := h.db.NodeSave(node); err != nil { + logErr(err, "Failed to persist/update node in the database") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from handlePoll -> update -> new hostinfo", + } + if stateUpdate.Valid() { + h.nodeNotifier.NotifyWithIgnore( + stateUpdate, + node.MachineKey.String()) + } + + return + } + } if err := h.db.NodeSave(node); err != nil { logErr(err, "Failed to persist/update node in the database") @@ -94,20 +167,15 @@ func (h *Headscale) handlePoll( return } - err := h.db.SaveNodeRoutes(node) - if err != nil { - logErr(err, "Error processing node routes") - http.Error(writer, "", http.StatusInternalServerError) - - return + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{&change}, + } + if stateUpdate.Valid() { + h.nodeNotifier.NotifyWithIgnore( + stateUpdate, + node.MachineKey.String()) } - - h.nodeNotifier.NotifyWithIgnore( - types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: types.Nodes{node}, - }, - node.MachineKey.String()) writer.WriteHeader(http.StatusOK) if f, ok := writer.(http.Flusher); ok { @@ -115,7 +183,7 @@ func (h *Headscale) handlePoll( } return - + } else if mapRequest.OmitPeers && !mapRequest.Stream && mapRequest.ReadOnly { // ReadOnly is whether the client just wants to fetch the // MapResponse, without updating their Endpoints. 
The // Endpoints field will be ignored and LastSeen will not be @@ -133,12 +201,39 @@ func (h *Headscale) handlePoll( return } - now := time.Now().UTC() - node.LastSeen = &now - node.Hostname = mapRequest.Hostinfo.Hostname - node.Hostinfo = mapRequest.Hostinfo - node.DiscoKey = mapRequest.DiscoKey - node.Endpoints = mapRequest.Endpoints + change := node.PeerChangeFromMapRequest(mapRequest) + + // A stream is being set up, the node is Online + online := true + change.Online = &online + + node.ApplyPeerChange(&change) + + // Only save HostInfo if changed, update routes if changed + // TODO(kradalby): Remove when capver is over 68 + if !node.Hostinfo.Equal(mapRequest.Hostinfo) { + oldRoutes := node.Hostinfo.RoutableIPs + newRoutes := mapRequest.Hostinfo.RoutableIPs + + node.Hostinfo = mapRequest.Hostinfo + + if !xslices.Equal(oldRoutes, newRoutes) { + _, err := h.db.SaveNodeRoutes(node) + if err != nil { + logErr(err, "Error processing node routes") + http.Error(writer, "", http.StatusInternalServerError) + + return + } + } + } + + if err := h.db.NodeSave(node); err != nil { + logErr(err, "Failed to persist/update node in the database") + http.Error(writer, "", http.StatusInternalServerError) + + return + } // When a node connects to control, list the peers it has at // that given point, further updates are kept in memory in @@ -152,6 +247,11 @@ func (h *Headscale) handlePoll( return } + for _, peer := range peers { + online := h.nodeNotifier.IsConnected(peer.MachineKey) + peer.IsOnline = &online + } + mapp := mapper.NewMapper( node, peers, @@ -162,11 +262,6 @@ func (h *Headscale) handlePoll( h.cfg.RandomizeClientPort, ) - err = h.db.SaveNodeRoutes(node) - if err != nil { - logErr(err, "Error processing node routes") - } - // update ACLRules with peer informations (to update server tags if necessary) if h.ACLPolicy != nil { // update routes with peer information @@ -176,14 +271,6 @@ func (h *Headscale) handlePoll( } } - // TODO(kradalby): Save specific stuff, not whole object. - if err := h.db.NodeSave(node); err != nil { - logErr(err, "Failed to persist/update node in the database") - http.Error(writer, "", http.StatusInternalServerError) - - return - } - logInfo("Sending initial map") mapResp, err := mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) @@ -208,18 +295,26 @@ func (h *Headscale) handlePoll( return } - h.nodeNotifier.NotifyWithIgnore( - types.StateUpdate{ - Type: types.StatePeerChanged, - Changed: types.Nodes{node}, - }, - node.MachineKey.String()) + stateUpdate := types.StateUpdate{ + Type: types.StatePeerChanged, + ChangeNodes: types.Nodes{node}, + Message: "called from handlePoll -> new node added", + } + if stateUpdate.Valid() { + h.nodeNotifier.NotifyWithIgnore( + stateUpdate, + node.MachineKey.String()) + } // Set up the client stream h.pollNetMapStreamWG.Add(1) defer h.pollNetMapStreamWG.Done() - updateChan := make(chan types.StateUpdate) + // Use a buffered channel in case a node is not fully ready + // to receive a message to make sure we dont block the entire + // notifier. + // 12 is arbitrarily chosen. 
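+	// Note that the buffer only absorbs short bursts: NotifyWithIgnore
+	// still performs a blocking send on this channel, so once it is full a
+	// slow client will back-pressure the notifier again.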
+ updateChan := make(chan types.StateUpdate, 12) defer closeChanWithLog(updateChan, node.Hostname, "updateChan") // Register the node's update channel @@ -233,6 +328,10 @@ func (h *Headscale) handlePoll( ctx, cancel := context.WithCancel(ctx) defer cancel() + if len(node.Routes) > 0 { + go h.db.EnsureFailoverRouteIsAvailable(node) + } + for { logInfo("Waiting for update on stream channel") select { @@ -262,14 +361,7 @@ func (h *Headscale) handlePoll( // One alternative is to split these different channels into // goroutines, but then you might have a problem without a lock // if a keepalive is written at the same time as an update. - go func() { - err = h.db.UpdateLastSeen(node) - if err != nil { - logErr(err, "Cannot update node LastSeen") - - return - } - }() + go h.updateNodeOnlineStatus(true, node) case update := <-updateChan: logInfo("Received update") @@ -279,18 +371,35 @@ func (h *Headscale) handlePoll( var err error switch update.Type { + case types.StateFullUpdate: + logInfo("Sending Full MapResponse") + + data, err = mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) case types.StatePeerChanged: - logInfo("Sending PeerChanged MapResponse") - data, err = mapp.PeerChangedResponse(mapRequest, node, update.Changed, h.ACLPolicy) + logInfo(fmt.Sprintf("Sending Changed MapResponse: %s", update.Message)) + + for _, node := range update.ChangeNodes { + // If a node is not reported to be online, it might be + // because the value is outdated, check with the notifier. + // However, if it is set to Online, and not in the notifier, + // this might be because it has announced itself, but not + // reached the stage to actually create the notifier channel. + if node.IsOnline != nil && !*node.IsOnline { + isOnline := h.nodeNotifier.IsConnected(node.MachineKey) + node.IsOnline = &isOnline + } + } + + data, err = mapp.PeerChangedResponse(mapRequest, node, update.ChangeNodes, h.ACLPolicy, update.Message) + case types.StatePeerChangedPatch: + logInfo("Sending PeerChangedPatch MapResponse") + data, err = mapp.PeerChangedPatchResponse(mapRequest, node, update.ChangePatches, h.ACLPolicy) case types.StatePeerRemoved: logInfo("Sending PeerRemoved MapResponse") data, err = mapp.PeerRemovedResponse(mapRequest, node, update.Removed) case types.StateDERPUpdated: logInfo("Sending DERPUpdate MapResponse") data, err = mapp.DERPMapResponse(mapRequest, node, update.DERPMap) - case types.StateFullUpdate: - logInfo("Sending Full MapResponse") - data, err = mapp.FullMapResponse(mapRequest, node, h.ACLPolicy) } if err != nil { @@ -299,54 +408,45 @@ func (h *Headscale) handlePoll( return } - _, err = writer.Write(data) - if err != nil { - logErr(err, "Could not write the map response") - - updateRequestsSentToNode.WithLabelValues(node.User.Name, node.Hostname, "failed"). - Inc() - - return - } - - if flusher, ok := writer.(http.Flusher); ok { - flusher.Flush() - } else { - log.Error().Msg("Failed to create http flusher") - - return - } - - // See comment in keepAliveTicker - go func() { - err = h.db.UpdateLastSeen(node) + // Only send update if there is change + if data != nil { + _, err = writer.Write(data) if err != nil { - logErr(err, "Cannot update node LastSeen") + logErr(err, "Could not write the map response") + + updateRequestsSentToNode.WithLabelValues(node.User.Name, node.Hostname, "failed"). + Inc() return } - }() - log.Info(). - Caller(). - Bool("readOnly", mapRequest.ReadOnly). - Bool("omitPeers", mapRequest.OmitPeers). - Bool("stream", mapRequest.Stream). - Str("node_key", node.NodeKey.ShortString()). 
- Str("node", node.Hostname). - TimeDiff("timeSpent", time.Now(), now). - Msg("update sent") + if flusher, ok := writer.(http.Flusher); ok { + flusher.Flush() + } else { + log.Error().Msg("Failed to create http flusher") + + return + } + + log.Info(). + Caller(). + Bool("readOnly", mapRequest.ReadOnly). + Bool("omitPeers", mapRequest.OmitPeers). + Bool("stream", mapRequest.Stream). + Str("node_key", node.NodeKey.ShortString()). + Str("machine_key", node.MachineKey.ShortString()). + Str("node", node.Hostname). + TimeDiff("timeSpent", time.Now(), now). + Msg("update sent") + } + case <-ctx.Done(): logInfo("The client has closed the connection") - go func() { - err = h.db.UpdateLastSeen(node) - if err != nil { - logErr(err, "Cannot update node LastSeen") + go h.updateNodeOnlineStatus(false, node) - return - } - }() + // Failover the node's routes if any. + go h.db.FailoverNodeRoutesWithNotify(node) // The connection has been closed, so we can stop polling. return @@ -359,6 +459,36 @@ func (h *Headscale) handlePoll( } } +// updateNodeOnlineStatus records the last seen status of a node and notifies peers +// about change in their online/offline status. +// It takes a StateUpdateType of either StatePeerOnlineChanged or StatePeerOfflineChanged. +func (h *Headscale) updateNodeOnlineStatus(online bool, node *types.Node) { + now := time.Now() + + node.LastSeen = &now + + statusUpdate := types.StateUpdate{ + Type: types.StatePeerChangedPatch, + ChangePatches: []*tailcfg.PeerChange{ + { + NodeID: tailcfg.NodeID(node.ID), + Online: &online, + LastSeen: &now, + }, + }, + } + if statusUpdate.Valid() { + h.nodeNotifier.NotifyWithIgnore(statusUpdate, node.MachineKey.String()) + } + + err := h.db.UpdateLastSeen(node) + if err != nil { + log.Error().Err(err).Msg("Cannot update node LastSeen") + + return + } +} + func closeChanWithLog[C chan []byte | chan struct{} | chan types.StateUpdate](channel C, node, name string) { log.Trace(). Str("handler", "PollNetMap"). @@ -378,8 +508,6 @@ func (h *Headscale) handleLiteRequest( mapp := mapper.NewMapper( node, - // TODO(kradalby): It might not be acceptable to send - // an empty peer list here. 
types.Nodes{}, h.DERPMap, h.cfg.BaseDomain, @@ -405,3 +533,38 @@ func (h *Headscale) handleLiteRequest( logErr(err, "Failed to write response") } } + +func logTracePeerChange(hostname string, hostinfoChange bool, change *tailcfg.PeerChange) { + trace := log.Trace().Str("node_id", change.NodeID.String()).Str("hostname", hostname) + + if change.Key != nil { + trace = trace.Str("node_key", change.Key.ShortString()) + } + + if change.DiscoKey != nil { + trace = trace.Str("disco_key", change.DiscoKey.ShortString()) + } + + if change.Online != nil { + trace = trace.Bool("online", *change.Online) + } + + if change.Endpoints != nil { + eps := make([]string, len(change.Endpoints)) + for idx, ep := range change.Endpoints { + eps[idx] = ep.String() + } + + trace = trace.Strs("endpoints", eps) + } + + if hostinfoChange { + trace = trace.Bool("hostinfo_changed", hostinfoChange) + } + + if change.DERPRegion != 0 { + trace = trace.Int("derp_region", change.DERPRegion) + } + + trace.Time("last_seen", *change.LastSeen).Msg("PeerChange received") +} diff --git a/hscontrol/types/common.go b/hscontrol/types/common.go index 39060ac5..6e8bfff8 100644 --- a/hscontrol/types/common.go +++ b/hscontrol/types/common.go @@ -84,20 +84,31 @@ type StateUpdateType int const ( StateFullUpdate StateUpdateType = iota + // StatePeerChanged is used for updates that needs + // to be calculated with all peers and all policy rules. + // This would typically be things that include tags, routes + // and similar. StatePeerChanged + StatePeerChangedPatch StatePeerRemoved StateDERPUpdated ) // StateUpdate is an internal message containing information about // a state change that has happened to the network. +// If type is StateFullUpdate, all fields are ignored. type StateUpdate struct { // The type of update Type StateUpdateType - // Changed must be set when Type is StatePeerChanged and - // contain the Node IDs of nodes that have changed. - Changed Nodes + // ChangeNodes must be set when Type is StatePeerAdded + // and StatePeerChanged and contains the full node + // object for added nodes. + ChangeNodes Nodes + + // ChangePatches must be set when Type is StatePeerChangedPatch + // and contains a populated PeerChange object. + ChangePatches []*tailcfg.PeerChange // Removed must be set when Type is StatePeerRemoved and // contain a list of the nodes that has been removed from @@ -106,5 +117,36 @@ type StateUpdate struct { // DERPMap must be set when Type is StateDERPUpdated and // contain the new DERP Map. - DERPMap tailcfg.DERPMap + DERPMap *tailcfg.DERPMap + + // Additional message for tracking origin or what being + // updated, useful for ambiguous updates like StatePeerChanged. + Message string +} + +// Valid reports if a StateUpdate is correctly filled and +// panics if the mandatory fields for a type is not +// filled. +// Reports true if valid. 
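+//
+// An illustrative call site, mirroring how poll.go uses it (change,
+// notifier and self stand in for values available at the call site):
+//
+//	update := StateUpdate{
+//		Type:          StatePeerChangedPatch,
+//		ChangePatches: []*tailcfg.PeerChange{&change},
+//	}
+//	if update.Valid() {
+//		notifier.NotifyWithIgnore(update, self.MachineKey.String())
+//	}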
+func (su *StateUpdate) Valid() bool { + switch su.Type { + case StatePeerChanged: + if su.ChangeNodes == nil { + panic("Mandatory field ChangeNodes is not set on StatePeerChanged update") + } + case StatePeerChangedPatch: + if su.ChangePatches == nil { + panic("Mandatory field ChangePatches is not set on StatePeerChangedPatch update") + } + case StatePeerRemoved: + if su.Removed == nil { + panic("Mandatory field Removed is not set on StatePeerRemove update") + } + case StateDERPUpdated: + if su.DERPMap == nil { + panic("Mandatory field DERPMap is not set on StateDERPUpdated update") + } + } + + return true } diff --git a/hscontrol/types/node.go b/hscontrol/types/node.go index a2fdb916..bb88fc32 100644 --- a/hscontrol/types/node.go +++ b/hscontrol/types/node.go @@ -21,7 +21,9 @@ import ( var ( ErrNodeAddressesInvalid = errors.New("failed to parse node addresses") - ErrHostnameTooLong = errors.New("hostname too long") + ErrHostnameTooLong = errors.New("hostname too long, cannot except 255 ASCII chars") + ErrNodeHasNoGivenName = errors.New("node has no given name") + ErrNodeUserHasNoName = errors.New("node user has no name") ) // Node is a Headscale client. @@ -95,22 +97,14 @@ type Node struct { CreatedAt time.Time UpdatedAt time.Time DeletedAt *time.Time + + IsOnline *bool `gorm:"-"` } type ( Nodes []*Node ) -func (nodes Nodes) OnlineNodeMap() map[tailcfg.NodeID]bool { - ret := make(map[tailcfg.NodeID]bool) - - for _, node := range nodes { - ret[tailcfg.NodeID(node.ID)] = node.IsOnline() - } - - return ret -} - type NodeAddresses []netip.Addr func (na NodeAddresses) Sort() { @@ -206,21 +200,6 @@ func (node Node) IsExpired() bool { return time.Now().UTC().After(*node.Expiry) } -// IsOnline returns if the node is connected to Headscale. -// This is really a naive implementation, as we don't really see -// if there is a working connection between the client and the server. -func (node *Node) IsOnline() bool { - if node.LastSeen == nil { - return false - } - - if node.IsExpired() { - return false - } - - return node.LastSeen.After(time.Now().Add(-KeepAliveInterval)) -} - // IsEphemeral returns if the node is registered as an Ephemeral node. 
// https://tailscale.com/kb/1111/ephemeral-nodes/ func (node *Node) IsEphemeral() bool { @@ -339,7 +318,6 @@ func (node *Node) Proto() *v1.Node { GivenName: node.GivenName, User: node.User.Proto(), ForcedTags: node.ForcedTags, - Online: node.IsOnline(), // TODO(kradalby): Implement register method enum converter // RegisterMethod: , @@ -365,6 +343,14 @@ func (node *Node) Proto() *v1.Node { func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) { var hostname string if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS + if node.GivenName == "" { + return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeHasNoGivenName) + } + + if node.User.Name == "" { + return "", fmt.Errorf("failed to create valid FQDN: %w", ErrNodeUserHasNoName) + } + hostname = fmt.Sprintf( "%s.%s.%s", node.GivenName, @@ -373,7 +359,7 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri ) if len(hostname) > MaxHostnameLength { return "", fmt.Errorf( - "hostname %q is too long it cannot except 255 ASCII chars: %w", + "failed to create valid FQDN (%s): %w", hostname, ErrHostnameTooLong, ) @@ -385,8 +371,98 @@ func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (stri return hostname, nil } -func (node Node) String() string { - return node.Hostname +// func (node *Node) String() string { +// return node.Hostname +// } + +// PeerChangeFromMapRequest takes a MapRequest and compares it to the node +// to produce a PeerChange struct that can be used to updated the node and +// inform peers about smaller changes to the node. +// When a field is added to this function, remember to also add it to: +// - node.ApplyPeerChange +// - logTracePeerChange in poll.go +func (node *Node) PeerChangeFromMapRequest(req tailcfg.MapRequest) tailcfg.PeerChange { + ret := tailcfg.PeerChange{ + NodeID: tailcfg.NodeID(node.ID), + } + + if node.NodeKey.String() != req.NodeKey.String() { + ret.Key = &req.NodeKey + } + + if node.DiscoKey.String() != req.DiscoKey.String() { + ret.DiscoKey = &req.DiscoKey + } + + if node.Hostinfo != nil && + node.Hostinfo.NetInfo != nil && + req.Hostinfo != nil && + req.Hostinfo.NetInfo != nil && + node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP { + ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP + } + + if req.Hostinfo != nil && req.Hostinfo.NetInfo != nil { + // If there is no stored Hostinfo or NetInfo, use + // the new PreferredDERP. + if node.Hostinfo == nil { + ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP + } else if node.Hostinfo.NetInfo == nil { + ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP + } else { + // If there is a PreferredDERP check if it has changed. + if node.Hostinfo.NetInfo.PreferredDERP != req.Hostinfo.NetInfo.PreferredDERP { + ret.DERPRegion = req.Hostinfo.NetInfo.PreferredDERP + } + } + } + + // TODO(kradalby): Find a good way to compare updates + ret.Endpoints = req.Endpoints + + now := time.Now() + ret.LastSeen = &now + + return ret +} + +// ApplyPeerChange takes a PeerChange struct and updates the node. +func (node *Node) ApplyPeerChange(change *tailcfg.PeerChange) { + if change.Key != nil { + node.NodeKey = *change.Key + } + + if change.DiscoKey != nil { + node.DiscoKey = *change.DiscoKey + } + + if change.Online != nil { + node.IsOnline = change.Online + } + + if change.Endpoints != nil { + node.Endpoints = change.Endpoints + } + + // This might technically not be useful as we replace + // the whole hostinfo blob when it has changed. 
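+	// PreferredDERP lives two levels deep (Hostinfo -> NetInfo), so each
+	// missing level has to be created before the region can be recorded;
+	// a DERPRegion of 0 is treated as "no change".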
+ if change.DERPRegion != 0 { + if node.Hostinfo == nil { + node.Hostinfo = &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: change.DERPRegion, + }, + } + } else if node.Hostinfo.NetInfo == nil { + node.Hostinfo.NetInfo = &tailcfg.NetInfo{ + PreferredDERP: change.DERPRegion, + } + } else { + node.Hostinfo.NetInfo.PreferredDERP = change.DERPRegion + } + } + + node.LastSeen = change.LastSeen } func (nodes Nodes) String() string { diff --git a/hscontrol/types/node_test.go b/hscontrol/types/node_test.go index 85fa79c4..7e6c9840 100644 --- a/hscontrol/types/node_test.go +++ b/hscontrol/types/node_test.go @@ -4,7 +4,10 @@ import ( "net/netip" "testing" + "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "tailscale.com/tailcfg" + "tailscale.com/types/key" ) func Test_NodeCanAccess(t *testing.T) { @@ -139,3 +142,227 @@ func TestNodeAddressesOrder(t *testing.T) { } } } + +func TestNodeFQDN(t *testing.T) { + tests := []struct { + name string + node Node + dns tailcfg.DNSConfig + domain string + want string + wantErr string + }{ + { + name: "all-set", + node: Node{ + GivenName: "test", + User: User{ + Name: "user", + }, + }, + dns: tailcfg.DNSConfig{ + Proxied: true, + }, + domain: "example.com", + want: "test.user.example.com", + }, + { + name: "no-given-name", + node: Node{ + User: User{ + Name: "user", + }, + }, + dns: tailcfg.DNSConfig{ + Proxied: true, + }, + domain: "example.com", + wantErr: "failed to create valid FQDN: node has no given name", + }, + { + name: "no-user-name", + node: Node{ + GivenName: "test", + User: User{}, + }, + dns: tailcfg.DNSConfig{ + Proxied: true, + }, + domain: "example.com", + wantErr: "failed to create valid FQDN: node user has no name", + }, + { + name: "no-magic-dns", + node: Node{ + GivenName: "test", + User: User{ + Name: "user", + }, + }, + dns: tailcfg.DNSConfig{ + Proxied: false, + }, + domain: "example.com", + want: "test", + }, + { + name: "no-dnsconfig", + node: Node{ + GivenName: "test", + User: User{ + Name: "user", + }, + }, + domain: "example.com", + want: "test", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.node.GetFQDN(&tc.dns, tc.domain) + + if (err != nil) && (err.Error() != tc.wantErr) { + t.Errorf("GetFQDN() error = %s, wantErr %s", err, tc.wantErr) + + return + } + + if diff := cmp.Diff(tc.want, got); diff != "" { + t.Errorf("GetFQDN unexpected result (-want +got):\n%s", diff) + } + }) + } +} + +func TestPeerChangeFromMapRequest(t *testing.T) { + nKeys := []key.NodePublic{ + key.NewNode().Public(), + key.NewNode().Public(), + key.NewNode().Public(), + } + + dKeys := []key.DiscoPublic{ + key.NewDisco().Public(), + key.NewDisco().Public(), + key.NewDisco().Public(), + } + + tests := []struct { + name string + node Node + mapReq tailcfg.MapRequest + want tailcfg.PeerChange + }{ + { + name: "preferred-derp-changed", + node: Node{ + ID: 1, + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Endpoints: []netip.AddrPort{}, + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 998, + }, + }, + }, + mapReq: tailcfg.MapRequest{ + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 999, + }, + }, + }, + want: tailcfg.PeerChange{ + NodeID: 1, + DERPRegion: 999, + }, + }, + { + name: "preferred-derp-no-changed", + node: Node{ + ID: 1, + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Endpoints: []netip.AddrPort{}, + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 100, 
+ }, + }, + }, + mapReq: tailcfg.MapRequest{ + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 100, + }, + }, + }, + want: tailcfg.PeerChange{ + NodeID: 1, + DERPRegion: 0, + }, + }, + { + name: "preferred-derp-no-mapreq-netinfo", + node: Node{ + ID: 1, + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Endpoints: []netip.AddrPort{}, + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 200, + }, + }, + }, + mapReq: tailcfg.MapRequest{ + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Hostinfo: &tailcfg.Hostinfo{}, + }, + want: tailcfg.PeerChange{ + NodeID: 1, + DERPRegion: 0, + }, + }, + { + name: "preferred-derp-no-node-netinfo", + node: Node{ + ID: 1, + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Endpoints: []netip.AddrPort{}, + Hostinfo: &tailcfg.Hostinfo{}, + }, + mapReq: tailcfg.MapRequest{ + NodeKey: nKeys[0], + DiscoKey: dKeys[0], + Hostinfo: &tailcfg.Hostinfo{ + NetInfo: &tailcfg.NetInfo{ + PreferredDERP: 200, + }, + }, + }, + want: tailcfg.PeerChange{ + NodeID: 1, + DERPRegion: 200, + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + got := tc.node.PeerChangeFromMapRequest(tc.mapReq) + + if diff := cmp.Diff(tc.want, got, cmpopts.IgnoreFields(tailcfg.PeerChange{}, "LastSeen")); diff != "" { + t.Errorf("Patch unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/types/routes.go b/hscontrol/types/routes.go index 3fd96702..697cbc36 100644 --- a/hscontrol/types/routes.go +++ b/hscontrol/types/routes.go @@ -19,6 +19,8 @@ type Route struct { NodeID uint64 Node Node + + // TODO(kradalby): change this custom type to netip.Prefix Prefix IPPrefix Advertised bool @@ -29,13 +31,17 @@ type Route struct { type Routes []Route func (r *Route) String() string { - return fmt.Sprintf("%s:%s", r.Node, netip.Prefix(r.Prefix).String()) + return fmt.Sprintf("%s:%s", r.Node.Hostname, netip.Prefix(r.Prefix).String()) } func (r *Route) IsExitRoute() bool { return netip.Prefix(r.Prefix) == ExitRouteV4 || netip.Prefix(r.Prefix) == ExitRouteV6 } +func (r *Route) IsAnnouncable() bool { + return r.Advertised && r.Enabled +} + func (rs Routes) Prefixes() []netip.Prefix { prefixes := make([]netip.Prefix, len(rs)) for i, r := range rs { @@ -45,6 +51,32 @@ func (rs Routes) Prefixes() []netip.Prefix { return prefixes } +// Primaries returns Primary routes from a list of routes. 
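+//
+// Illustrative use, assuming the routes were loaded with their Node
+// association populated:
+//
+//	for _, route := range rs.Primaries() {
+//		fmt.Printf("%s is primary for %s\n", route.Node.Hostname, netip.Prefix(route.Prefix))
+//	}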
+func (rs Routes) Primaries() Routes { + res := make(Routes, 0) + for _, route := range rs { + if route.IsPrimary { + res = append(res, route) + } + } + + return res +} + +func (rs Routes) PrefixMap() map[IPPrefix][]Route { + res := map[IPPrefix][]Route{} + + for _, route := range rs { + if _, ok := res[route.Prefix]; ok { + res[route.Prefix] = append(res[route.Prefix], route) + } else { + res[route.Prefix] = []Route{route} + } + } + + return res +} + func (rs Routes) Proto() []*v1.Route { protoRoutes := []*v1.Route{} diff --git a/hscontrol/types/routes_test.go b/hscontrol/types/routes_test.go new file mode 100644 index 00000000..ead4c595 --- /dev/null +++ b/hscontrol/types/routes_test.go @@ -0,0 +1,94 @@ +package types + +import ( + "fmt" + "net/netip" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/juanfont/headscale/hscontrol/util" +) + +func TestPrefixMap(t *testing.T) { + ipp := func(s string) IPPrefix { return IPPrefix(netip.MustParsePrefix(s)) } + + // TODO(kradalby): Remove when we have gotten rid of IPPrefix type + prefixComparer := cmp.Comparer(func(x, y IPPrefix) bool { + return x == y + }) + + tests := []struct { + rs Routes + want map[IPPrefix][]Route + }{ + { + rs: Routes{ + Route{ + Prefix: ipp("10.0.0.0/24"), + }, + }, + want: map[IPPrefix][]Route{ + ipp("10.0.0.0/24"): Routes{ + Route{ + Prefix: ipp("10.0.0.0/24"), + }, + }, + }, + }, + { + rs: Routes{ + Route{ + Prefix: ipp("10.0.0.0/24"), + }, + Route{ + Prefix: ipp("10.0.1.0/24"), + }, + }, + want: map[IPPrefix][]Route{ + ipp("10.0.0.0/24"): Routes{ + Route{ + Prefix: ipp("10.0.0.0/24"), + }, + }, + ipp("10.0.1.0/24"): Routes{ + Route{ + Prefix: ipp("10.0.1.0/24"), + }, + }, + }, + }, + { + rs: Routes{ + Route{ + Prefix: ipp("10.0.0.0/24"), + Enabled: true, + }, + Route{ + Prefix: ipp("10.0.0.0/24"), + Enabled: false, + }, + }, + want: map[IPPrefix][]Route{ + ipp("10.0.0.0/24"): Routes{ + Route{ + Prefix: ipp("10.0.0.0/24"), + Enabled: true, + }, + Route{ + Prefix: ipp("10.0.0.0/24"), + Enabled: false, + }, + }, + }, + }, + } + + for idx, tt := range tests { + t.Run(fmt.Sprintf("test-%d", idx), func(t *testing.T) { + got := tt.rs.PrefixMap() + if diff := cmp.Diff(tt.want, got, prefixComparer, util.MkeyComparer, util.NkeyComparer, util.DkeyComparer); diff != "" { + t.Errorf("PrefixMap() unexpected result (-want +got):\n%s", diff) + } + }) + } +} diff --git a/hscontrol/util/test.go b/hscontrol/util/test.go new file mode 100644 index 00000000..6d465426 --- /dev/null +++ b/hscontrol/util/test.go @@ -0,0 +1,32 @@ +package util + +import ( + "net/netip" + + "github.com/google/go-cmp/cmp" + "tailscale.com/types/key" +) + +var PrefixComparer = cmp.Comparer(func(x, y netip.Prefix) bool { + return x == y +}) + +var IPComparer = cmp.Comparer(func(x, y netip.Addr) bool { + return x.Compare(y) == 0 +}) + +var MkeyComparer = cmp.Comparer(func(x, y key.MachinePublic) bool { + return x.String() == y.String() +}) + +var NkeyComparer = cmp.Comparer(func(x, y key.NodePublic) bool { + return x.String() == y.String() +}) + +var DkeyComparer = cmp.Comparer(func(x, y key.DiscoPublic) bool { + return x.String() == y.String() +}) + +var Comparers []cmp.Option = []cmp.Option{ + IPComparer, PrefixComparer, MkeyComparer, NkeyComparer, DkeyComparer, +} diff --git a/integration/cli_test.go b/integration/cli_test.go index 6e7333ff..0ff0ffca 100644 --- a/integration/cli_test.go +++ b/integration/cli_test.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "sort" - "strconv" "testing" "time" @@ -22,7 +21,7 @@ func executeAndUnmarshal[T 
any](headscale ControlServer, command []string, resul err = json.Unmarshal([]byte(str), result) if err != nil { - return err + return fmt.Errorf("failed to unmarshal: %s\n command err: %s", err, str) } return nil @@ -178,7 +177,11 @@ func TestPreAuthKeyCommand(t *testing.T) { assert.Equal( t, []string{keys[0].GetId(), keys[1].GetId(), keys[2].GetId()}, - []string{listedPreAuthKeys[1].GetId(), listedPreAuthKeys[2].GetId(), listedPreAuthKeys[3].GetId()}, + []string{ + listedPreAuthKeys[1].GetId(), + listedPreAuthKeys[2].GetId(), + listedPreAuthKeys[3].GetId(), + }, ) assert.NotEmpty(t, listedPreAuthKeys[1].GetKey()) @@ -384,141 +387,6 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) { assert.Len(t, listedPreAuthKeys, 3) } -func TestEnablingRoutes(t *testing.T) { - IntegrationSkip(t) - t.Parallel() - - user := "enable-routing" - - scenario, err := NewScenario() - assertNoErrf(t, "failed to create scenario: %s", err) - defer scenario.Shutdown() - - spec := map[string]int{ - user: 3, - } - - err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute")) - assertNoErrHeadscaleEnv(t, err) - - allClients, err := scenario.ListTailscaleClients() - assertNoErrListClients(t, err) - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - headscale, err := scenario.Headscale() - assertNoErrGetHeadscale(t, err) - - // advertise routes using the up command - for i, client := range allClients { - routeStr := fmt.Sprintf("10.0.%d.0/24", i) - command := []string{ - "tailscale", - "set", - "--advertise-routes=" + routeStr, - } - _, _, err := client.Execute(command) - assertNoErrf(t, "failed to advertise route: %s", err) - } - - err = scenario.WaitForTailscaleSync() - assertNoErrSync(t, err) - - var routes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &routes, - ) - - assertNoErr(t, err) - assert.Len(t, routes, 3) - - for _, route := range routes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) - } - - for _, route := range routes { - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "enable", - "--route", - strconv.Itoa(int(route.GetId())), - }) - assertNoErr(t, err) - } - - var enablingRoutes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &enablingRoutes, - ) - assertNoErr(t, err) - assert.Len(t, enablingRoutes, 3) - - for _, route := range enablingRoutes { - assert.Equal(t, route.GetAdvertised(), true) - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) - } - - routeIDToBeDisabled := enablingRoutes[0].GetId() - - _, err = headscale.Execute( - []string{ - "headscale", - "routes", - "disable", - "--route", - strconv.Itoa(int(routeIDToBeDisabled)), - }) - assertNoErr(t, err) - - var disablingRoutes []*v1.Route - err = executeAndUnmarshal( - headscale, - []string{ - "headscale", - "routes", - "list", - "--output", - "json", - }, - &disablingRoutes, - ) - assertNoErr(t, err) - - for _, route := range disablingRoutes { - assert.Equal(t, true, route.GetAdvertised()) - - if route.GetId() == routeIDToBeDisabled { - assert.Equal(t, route.GetEnabled(), false) - assert.Equal(t, route.GetIsPrimary(), false) - } else { - assert.Equal(t, route.GetEnabled(), true) - assert.Equal(t, route.GetIsPrimary(), true) - } - } -} - func TestApiKeyCommand(t 
*testing.T) { IntegrationSkip(t) t.Parallel() diff --git a/integration/embedded_derp_test.go b/integration/embedded_derp_test.go index 4191a793..3a407496 100644 --- a/integration/embedded_derp_test.go +++ b/integration/embedded_derp_test.go @@ -44,6 +44,9 @@ func TestDERPServerScenario(t *testing.T) { headscaleConfig["HEADSCALE_DERP_SERVER_REGION_NAME"] = "Headscale Embedded DERP" headscaleConfig["HEADSCALE_DERP_SERVER_STUN_LISTEN_ADDR"] = "0.0.0.0:3478" headscaleConfig["HEADSCALE_DERP_SERVER_PRIVATE_KEY_PATH"] = "/tmp/derp.key" + // Envknob for enabling DERP debug logs + headscaleConfig["DERP_DEBUG_LOGS"] = "true" + headscaleConfig["DERP_PROBER_DEBUG_LOGS"] = "true" err = scenario.CreateHeadscaleEnv( spec, diff --git a/integration/general_test.go b/integration/general_test.go index afa93e74..2e0f7fe6 100644 --- a/integration/general_test.go +++ b/integration/general_test.go @@ -14,6 +14,8 @@ import ( "github.com/rs/zerolog/log" "github.com/samber/lo" "github.com/stretchr/testify/assert" + "tailscale.com/client/tailscale/apitype" + "tailscale.com/types/key" ) func TestPingAllByIP(t *testing.T) { @@ -248,9 +250,8 @@ func TestPingAllByHostname(t *testing.T) { defer scenario.Shutdown() spec := map[string]int{ - // Omit 1.16.2 (-1) because it does not have the FQDN field - "user3": len(MustTestVersions) - 1, - "user4": len(MustTestVersions) - 1, + "user3": len(MustTestVersions), + "user4": len(MustTestVersions), } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("pingallbyname")) @@ -296,8 +297,7 @@ func TestTaildrop(t *testing.T) { defer scenario.Shutdown() spec := map[string]int{ - // Omit 1.16.2 (-1) because it does not have the FQDN field - "taildrop": len(MustTestVersions) - 1, + "taildrop": len(MustTestVersions), } err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("taildrop")) @@ -313,6 +313,42 @@ func TestTaildrop(t *testing.T) { _, err = scenario.ListTailscaleClientsFQDNs() assertNoErrListFQDN(t, err) + for _, client := range allClients { + if !strings.Contains(client.Hostname(), "head") { + command := []string{"apk", "add", "curl"} + _, _, err := client.Execute(command) + if err != nil { + t.Fatalf("failed to install curl on %s, err: %s", client.Hostname(), err) + } + + } + curlCommand := []string{"curl", "--unix-socket", "/var/run/tailscale/tailscaled.sock", "http://local-tailscaled.sock/localapi/v0/file-targets"} + err = retry(10, 1*time.Second, func() error { + result, _, err := client.Execute(curlCommand) + if err != nil { + return err + } + var fts []apitype.FileTarget + err = json.Unmarshal([]byte(result), &fts) + if err != nil { + return err + } + + if len(fts) != len(allClients)-1 { + ftStr := fmt.Sprintf("FileTargets for %s:\n", client.Hostname()) + for _, ft := range fts { + ftStr += fmt.Sprintf("\t%s\n", ft.Node.Name) + } + return fmt.Errorf("client %s does not have all its peers as FileTargets, got %d, want: %d\n%s", client.Hostname(), len(fts), len(allClients)-1, ftStr) + } + + return err + }) + if err != nil { + t.Errorf("failed to query localapi for filetarget on %s, err: %s", client.Hostname(), err) + } + } + for _, client := range allClients { command := []string{"touch", fmt.Sprintf("/tmp/file_from_%s", client.Hostname())} @@ -347,8 +383,9 @@ func TestTaildrop(t *testing.T) { }) if err != nil { t.Fatalf( - "failed to send taildrop file on %s, err: %s", + "failed to send taildrop file on %s with command %q, err: %s", client.Hostname(), + strings.Join(command, " "), err, ) } @@ -517,25 +554,176 @@ func TestExpireNode(t 
*testing.T) { err = json.Unmarshal([]byte(result), &node) assertNoErr(t, err) + var expiredNodeKey key.NodePublic + err = expiredNodeKey.UnmarshalText([]byte(node.GetNodeKey())) + assertNoErr(t, err) + + t.Logf("Node %s with node_key %s has been expired", node.GetName(), expiredNodeKey.String()) + time.Sleep(30 * time.Second) - // Verify that the expired not is no longer present in the Peer list - // of connected nodes. + now := time.Now() + + // Verify that the expired node has been marked in all peers list. for _, client := range allClients { status, err := client.Status() assertNoErr(t, err) - for _, peerKey := range status.Peers() { - peerStatus := status.Peer[peerKey] - - peerPublicKey := strings.TrimPrefix(peerStatus.PublicKey.String(), "nodekey:") - - assert.NotEqual(t, node.GetNodeKey(), peerPublicKey) - } - if client.Hostname() != node.GetName() { - // Assert that we have the original count - self - expired node - assert.Len(t, status.Peers(), len(MustTestVersions)-2) + t.Logf("available peers of %s: %v", client.Hostname(), status.Peers()) + + // In addition to marking nodes expired, we filter them out during the map response + // this check ensures that the node is either not present, or that it is expired + // if it is in the map response. + if peerStatus, ok := status.Peer[expiredNodeKey]; ok { + assertNotNil(t, peerStatus.Expired) + assert.Truef(t, peerStatus.KeyExpiry.Before(now), "node %s should have a key expire before %s, was %s", peerStatus.HostName, now.String(), peerStatus.KeyExpiry) + assert.Truef(t, peerStatus.Expired, "node %s should be expired, expired is %v", peerStatus.HostName, peerStatus.Expired) + } + + // TODO(kradalby): We do not propogate expiry correctly, nodes should be aware + // of their status, and this should be sent directly to the node when its + // expired. This needs a notifier that goes directly to the node (currently we only do peers) + // so fix this in a follow up PR. + // } else { + // assert.True(t, status.Self.Expired) } } } + +func TestNodeOnlineLastSeenStatus(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + scenario, err := NewScenario() + assertNoErr(t, err) + defer scenario.Shutdown() + + spec := map[string]int{ + "user1": len(MustTestVersions), + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("onlinelastseen")) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + allIps, err := scenario.ListTailscaleClientsIPs() + assertNoErrListClientIPs(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + allAddrs := lo.Map(allIps, func(x netip.Addr, index int) string { + return x.String() + }) + + success := pingAllHelper(t, allClients, allAddrs) + t.Logf("before expire: %d successful pings out of %d", success, len(allClients)*len(allIps)) + + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + // Assert that we have the original count - self + assert.Len(t, status.Peers(), len(MustTestVersions)-1) + } + + headscale, err := scenario.Headscale() + assertNoErr(t, err) + + keepAliveInterval := 60 * time.Second + + // Duration is chosen arbitrarily, 10m is reported in #1561 + testDuration := 12 * time.Minute + start := time.Now() + end := start.Add(testDuration) + + log.Printf("Starting online test from %v to %v", start, end) + + for { + // Let the test run continuously for X minutes to verify + // all nodes stay connected and has the expected status over time. 
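+	// The 12 minute window is a bit longer than the 10 minutes after which
+	// #1561 reported nodes flapping to offline, and covers roughly twelve
+	// 60 second keepalive rounds.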
+ if end.Before(time.Now()) { + return + } + + result, err := headscale.Execute([]string{ + "headscale", "nodes", "list", "--output", "json", + }) + assertNoErr(t, err) + + var nodes []*v1.Node + err = json.Unmarshal([]byte(result), &nodes) + assertNoErr(t, err) + + now := time.Now() + + // Threshold with some leeway + lastSeenThreshold := now.Add(-keepAliveInterval - (10 * time.Second)) + + // Verify that headscale reports the nodes as online + for _, node := range nodes { + // All nodes should be online + assert.Truef( + t, + node.GetOnline(), + "expected %s to have online status in Headscale, marked as offline %s after start", + node.GetName(), + time.Since(start), + ) + + lastSeen := node.GetLastSeen().AsTime() + // All nodes should have been last seen between now and the keepAliveInterval + assert.Truef( + t, + lastSeen.After(lastSeenThreshold), + "lastSeen (%v) was not %s after the threshold (%v)", + lastSeen, + keepAliveInterval, + lastSeenThreshold, + ) + } + + // Verify that all nodes report all nodes to be online + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + // .Online is only available from CapVer 16, which + // is not present in 1.18 which is the lowest we + // test. + if strings.Contains(client.Hostname(), "1-18") { + continue + } + + // All peers of this nodess are reporting to be + // connected to the control server + assert.Truef( + t, + peerStatus.Online, + "expected node %s to be marked as online in %s peer list, marked as offline %s after start", + peerStatus.HostName, + client.Hostname(), + time.Since(start), + ) + + // from docs: last seen to tailcontrol; only present if offline + // assert.Nilf( + // t, + // peerStatus.LastSeen, + // "expected node %s to not have LastSeen set, got %s", + // peerStatus.HostName, + // peerStatus.LastSeen, + // ) + } + } + + // Check maximum once per second + time.Sleep(time.Second) + } +} diff --git a/integration/route_test.go b/integration/route_test.go new file mode 100644 index 00000000..489165a8 --- /dev/null +++ b/integration/route_test.go @@ -0,0 +1,780 @@ +package integration + +import ( + "fmt" + "log" + "net/netip" + "sort" + "strconv" + "testing" + "time" + + v1 "github.com/juanfont/headscale/gen/go/headscale/v1" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/assert" +) + +// This test is both testing the routes command and the propagation of +// routes. 
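+// It advertises one /24 per node, checks that nothing is routed before
+// the routes are enabled, enables them through the CLI, verifies that
+// every client sees the expected PrimaryRoutes, and finally disables one
+// route again and confirms it disappears from the peers' netmaps.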
+func TestEnablingRoutes(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + user := "enable-routing" + + scenario, err := NewScenario() + assertNoErrf(t, "failed to create scenario: %s", err) + defer scenario.Shutdown() + + spec := map[string]int{ + user: 3, + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute")) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + expectedRoutes := map[string]string{ + "1": "10.0.0.0/24", + "2": "10.0.1.0/24", + "3": "10.0.2.0/24", + } + + // advertise routes using the up command + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + expectedRoutes[string(status.Self.ID)], + } + _, _, err = client.Execute(command) + assertNoErrf(t, "failed to advertise route: %s", err) + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + var routes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routes, + ) + + assertNoErr(t, err) + assert.Len(t, routes, 3) + + for _, route := range routes { + assert.Equal(t, route.GetAdvertised(), true) + assert.Equal(t, route.GetEnabled(), false) + assert.Equal(t, route.GetIsPrimary(), false) + } + + // Verify that no routes has been sent to the client, + // they are not yet enabled. + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(t, peerStatus.PrimaryRoutes) + } + } + + // Enable all routes + for _, route := range routes { + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + "enable", + "--route", + strconv.Itoa(int(route.GetId())), + }) + assertNoErr(t, err) + } + + var enablingRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &enablingRoutes, + ) + assertNoErr(t, err) + assert.Len(t, enablingRoutes, 3) + + for _, route := range enablingRoutes { + assert.Equal(t, route.GetAdvertised(), true) + assert.Equal(t, route.GetEnabled(), true) + assert.Equal(t, route.GetIsPrimary(), true) + } + + time.Sleep(5 * time.Second) + + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.NotNil(t, peerStatus.PrimaryRoutes) + if peerStatus.PrimaryRoutes == nil { + continue + } + + pRoutes := peerStatus.PrimaryRoutes.AsSlice() + + assert.Len(t, pRoutes, 1) + + if len(pRoutes) > 0 { + peerRoute := peerStatus.PrimaryRoutes.AsSlice()[0] + + // id starts at 1, we created routes with 0 index + assert.Equalf( + t, + expectedRoutes[string(peerStatus.ID)], + peerRoute.String(), + "expected route %s to be present on peer %s (%s) in %s (%s) status", + expectedRoutes[string(peerStatus.ID)], + peerStatus.HostName, + peerStatus.ID, + client.Hostname(), + client.ID(), + ) + } + } + } + + routeToBeDisabled := enablingRoutes[0] + log.Printf("preparing to disable %v", routeToBeDisabled) + + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + 
"disable", + "--route", + strconv.Itoa(int(routeToBeDisabled.GetId())), + }) + assertNoErr(t, err) + + var disablingRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &disablingRoutes, + ) + assertNoErr(t, err) + + for _, route := range disablingRoutes { + assert.Equal(t, true, route.GetAdvertised()) + + if route.GetId() == routeToBeDisabled.GetId() { + assert.Equal(t, route.GetEnabled(), false) + assert.Equal(t, route.GetIsPrimary(), false) + } else { + assert.Equal(t, route.GetEnabled(), true) + assert.Equal(t, route.GetIsPrimary(), true) + } + } + + time.Sleep(5 * time.Second) + + // Verify that the clients can see the new routes + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + if string(peerStatus.ID) == fmt.Sprintf("%d", routeToBeDisabled.GetNode().GetId()) { + assert.Nilf( + t, + peerStatus.PrimaryRoutes, + "expected node %s to have no routes, got primary route (%v)", + peerStatus.HostName, + peerStatus.PrimaryRoutes, + ) + } + } + } +} + +func TestHASubnetRouterFailover(t *testing.T) { + IntegrationSkip(t) + t.Parallel() + + user := "enable-routing" + + scenario, err := NewScenario() + assertNoErrf(t, "failed to create scenario: %s", err) + defer scenario.Shutdown() + + spec := map[string]int{ + user: 3, + } + + err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clienableroute")) + assertNoErrHeadscaleEnv(t, err) + + allClients, err := scenario.ListTailscaleClients() + assertNoErrListClients(t, err) + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + headscale, err := scenario.Headscale() + assertNoErrGetHeadscale(t, err) + + expectedRoutes := map[string]string{ + "1": "10.0.0.0/24", + "2": "10.0.0.0/24", + } + + // Sort nodes by ID + sort.SliceStable(allClients, func(i, j int) bool { + statusI, err := allClients[i].Status() + if err != nil { + return false + } + + statusJ, err := allClients[j].Status() + if err != nil { + return false + } + + return statusI.Self.ID < statusJ.Self.ID + }) + + subRouter1 := allClients[0] + subRouter2 := allClients[1] + + client := allClients[2] + + // advertise HA route on node 1 and 2 + // ID 1 will be primary + // ID 2 will be secondary + for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + if route, ok := expectedRoutes[string(status.Self.ID)]; ok { + command := []string{ + "tailscale", + "set", + "--advertise-routes=" + route, + } + _, _, err = client.Execute(command) + assertNoErrf(t, "failed to advertise route: %s", err) + } + } + + err = scenario.WaitForTailscaleSync() + assertNoErrSync(t, err) + + var routes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routes, + ) + + assertNoErr(t, err) + assert.Len(t, routes, 2) + + for _, route := range routes { + assert.Equal(t, true, route.GetAdvertised()) + assert.Equal(t, false, route.GetEnabled()) + assert.Equal(t, false, route.GetIsPrimary()) + } + + // Verify that no routes has been sent to the client, + // they are not yet enabled. 
+ for _, client := range allClients { + status, err := client.Status() + assertNoErr(t, err) + + for _, peerKey := range status.Peers() { + peerStatus := status.Peer[peerKey] + + assert.Nil(t, peerStatus.PrimaryRoutes) + } + } + + // Enable all routes + for _, route := range routes { + _, err = headscale.Execute( + []string{ + "headscale", + "routes", + "enable", + "--route", + strconv.Itoa(int(route.GetId())), + }) + assertNoErr(t, err) + + time.Sleep(time.Second) + } + + var enablingRoutes []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &enablingRoutes, + ) + assertNoErr(t, err) + assert.Len(t, enablingRoutes, 2) + + // Node 1 is primary + assert.Equal(t, true, enablingRoutes[0].GetAdvertised()) + assert.Equal(t, true, enablingRoutes[0].GetEnabled()) + assert.Equal(t, true, enablingRoutes[0].GetIsPrimary()) + + // Node 2 is not primary + assert.Equal(t, true, enablingRoutes[1].GetAdvertised()) + assert.Equal(t, true, enablingRoutes[1].GetEnabled()) + assert.Equal(t, false, enablingRoutes[1].GetIsPrimary()) + + // Verify that the client has routes from the primary machine + srs1, err := subRouter1.Status() + srs2, err := subRouter2.Status() + + clientStatus, err := client.Status() + assertNoErr(t, err) + + srs1PeerStatus := clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus := clientStatus.Peer[srs2.Self.PublicKey] + + assertNotNil(t, srs1PeerStatus.PrimaryRoutes) + assert.Nil(t, srs2PeerStatus.PrimaryRoutes) + + assert.Contains( + t, + srs1PeerStatus.PrimaryRoutes.AsSlice(), + netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]), + ) + + // Take down the current primary + t.Logf("taking down subnet router 1 (%s)", subRouter1.Hostname()) + err = subRouter1.Down() + assertNoErr(t, err) + + time.Sleep(5 * time.Second) + + var routesAfterMove []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routesAfterMove, + ) + assertNoErr(t, err) + assert.Len(t, routesAfterMove, 2) + + // Node 1 is not primary + assert.Equal(t, true, routesAfterMove[0].GetAdvertised()) + assert.Equal(t, true, routesAfterMove[0].GetEnabled()) + assert.Equal(t, false, routesAfterMove[0].GetIsPrimary()) + + // Node 2 is primary + assert.Equal(t, true, routesAfterMove[1].GetAdvertised()) + assert.Equal(t, true, routesAfterMove[1].GetEnabled()) + assert.Equal(t, true, routesAfterMove[1].GetIsPrimary()) + + // TODO(kradalby): Check client status + // Route is expected to be on SR2 + + srs2, err = subRouter2.Status() + + clientStatus, err = client.Status() + assertNoErr(t, err) + + srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey] + srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey] + + assert.Nil(t, srs1PeerStatus.PrimaryRoutes) + assertNotNil(t, srs2PeerStatus.PrimaryRoutes) + + if srs2PeerStatus.PrimaryRoutes != nil { + assert.Contains( + t, + srs2PeerStatus.PrimaryRoutes.AsSlice(), + netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]), + ) + } + + // Take down subnet router 2, leaving none available + t.Logf("taking down subnet router 2 (%s)", subRouter2.Hostname()) + err = subRouter2.Down() + assertNoErr(t, err) + + time.Sleep(5 * time.Second) + + var routesAfterBothDown []*v1.Route + err = executeAndUnmarshal( + headscale, + []string{ + "headscale", + "routes", + "list", + "--output", + "json", + }, + &routesAfterBothDown, + ) + assertNoErr(t, err) + assert.Len(t, routesAfterBothDown, 2) + + // Node 1 is not primary + 
assert.Equal(t, true, routesAfterBothDown[0].GetAdvertised())
+	assert.Equal(t, true, routesAfterBothDown[0].GetEnabled())
+	assert.Equal(t, false, routesAfterBothDown[0].GetIsPrimary())
+
+	// Node 2 is still primary:
+	// if a node goes down, but no other suitable route is
+	// available, the last known good route is kept.
+	assert.Equal(t, true, routesAfterBothDown[1].GetAdvertised())
+	assert.Equal(t, true, routesAfterBothDown[1].GetEnabled())
+	assert.Equal(t, true, routesAfterBothDown[1].GetIsPrimary())
+
+	// TODO(kradalby): Check client status
+	// Both are expected to be down
+
+	// Verify that subnet router 1 no longer presents the route,
+	// while subnet router 2 still presents the last known good route.
+	clientStatus, err = client.Status()
+	assertNoErr(t, err)
+
+	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+
+	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
+	assertNotNil(t, srs2PeerStatus.PrimaryRoutes)
+
+	if srs2PeerStatus.PrimaryRoutes != nil {
+		assert.Contains(
+			t,
+			srs2PeerStatus.PrimaryRoutes.AsSlice(),
+			netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
+		)
+	}
+
+	// Bring up subnet router 1, making the route available from there.
+	t.Logf("bringing up subnet router 1 (%s)", subRouter1.Hostname())
+	err = subRouter1.Up()
+	assertNoErr(t, err)
+
+	time.Sleep(5 * time.Second)
+
+	var routesAfter1Up []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&routesAfter1Up,
+	)
+	assertNoErr(t, err)
+	assert.Len(t, routesAfter1Up, 2)
+
+	// Node 1 is primary
+	assert.Equal(t, true, routesAfter1Up[0].GetAdvertised())
+	assert.Equal(t, true, routesAfter1Up[0].GetEnabled())
+	assert.Equal(t, true, routesAfter1Up[0].GetIsPrimary())
+
+	// Node 2 is not primary
+	assert.Equal(t, true, routesAfter1Up[1].GetAdvertised())
+	assert.Equal(t, true, routesAfter1Up[1].GetEnabled())
+	assert.Equal(t, false, routesAfter1Up[1].GetIsPrimary())
+
+	// Verify that the route is announced from subnet router 1
+	clientStatus, err = client.Status()
+	assertNoErr(t, err)
+
+	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+
+	assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
+	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
+
+	if srs1PeerStatus.PrimaryRoutes != nil {
+		assert.Contains(
+			t,
+			srs1PeerStatus.PrimaryRoutes.AsSlice(),
+			netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
+		)
+	}
+
+	// Bring up subnet router 2; no change is expected.
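+	// Router 1 keeps the primary; router 2 goes back to being a standby.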
+	t.Logf("bringing up subnet router 2 (%s)", subRouter2.Hostname())
+	err = subRouter2.Up()
+	assertNoErr(t, err)
+
+	time.Sleep(5 * time.Second)
+
+	var routesAfter2Up []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&routesAfter2Up,
+	)
+	assertNoErr(t, err)
+	assert.Len(t, routesAfter2Up, 2)
+
+	// Node 1 is still primary
+	assert.Equal(t, true, routesAfter2Up[0].GetAdvertised())
+	assert.Equal(t, true, routesAfter2Up[0].GetEnabled())
+	assert.Equal(t, true, routesAfter2Up[0].GetIsPrimary())
+
+	// Node 2 is not primary
+	assert.Equal(t, true, routesAfter2Up[1].GetAdvertised())
+	assert.Equal(t, true, routesAfter2Up[1].GetEnabled())
+	assert.Equal(t, false, routesAfter2Up[1].GetIsPrimary())
+
+	// Verify that the route is announced from subnet router 1
+	clientStatus, err = client.Status()
+	assertNoErr(t, err)
+
+	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+
+	assert.NotNil(t, srs1PeerStatus.PrimaryRoutes)
+	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
+
+	if srs1PeerStatus.PrimaryRoutes != nil {
+		assert.Contains(
+			t,
+			srs1PeerStatus.PrimaryRoutes.AsSlice(),
+			netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
+		)
+	}
+
+	// Disable the route of subnet router 1, making it failover to 2
+	t.Logf("disabling route in subnet router 1 (%s)", subRouter1.Hostname())
+	_, err = headscale.Execute(
+		[]string{
+			"headscale",
+			"routes",
+			"disable",
+			"--route",
+			fmt.Sprintf("%d", routesAfter2Up[0].GetId()),
+		})
+	assertNoErr(t, err)
+
+	time.Sleep(5 * time.Second)
+
+	var routesAfterDisabling1 []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&routesAfterDisabling1,
+	)
+	assertNoErr(t, err)
+	assert.Len(t, routesAfterDisabling1, 2)
+
+	// Node 1 is not primary
+	assert.Equal(t, true, routesAfterDisabling1[0].GetAdvertised())
+	assert.Equal(t, false, routesAfterDisabling1[0].GetEnabled())
+	assert.Equal(t, false, routesAfterDisabling1[0].GetIsPrimary())
+
+	// Node 2 is primary
+	assert.Equal(t, true, routesAfterDisabling1[1].GetAdvertised())
+	assert.Equal(t, true, routesAfterDisabling1[1].GetEnabled())
+	assert.Equal(t, true, routesAfterDisabling1[1].GetIsPrimary())
+
+	// Verify that the route is now announced from subnet router 2
+	clientStatus, err = client.Status()
+	assertNoErr(t, err)
+
+	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+
+	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
+	assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)
+
+	if srs2PeerStatus.PrimaryRoutes != nil {
+		assert.Contains(
+			t,
+			srs2PeerStatus.PrimaryRoutes.AsSlice(),
+			netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
+		)
+	}
+
+	// Enable the route of subnet router 1 again; no change is expected
+	t.Logf("enabling route in subnet router 1 (%s)", subRouter1.Hostname())
+	_, err = headscale.Execute(
+		[]string{
+			"headscale",
+			"routes",
+			"enable",
+			"--route",
+			fmt.Sprintf("%d", routesAfter2Up[0].GetId()),
+		})
+	assertNoErr(t, err)
+
+	time.Sleep(5 * time.Second)
+
+	var routesAfterEnabling1 []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&routesAfterEnabling1,
+	)
+	assertNoErr(t, err)
+	assert.Len(t, routesAfterEnabling1, 2)
+
+	// Node 1 is not primary
+	assert.Equal(t, true, routesAfterEnabling1[0].GetAdvertised())
+	assert.Equal(t, true, routesAfterEnabling1[0].GetEnabled())
+	assert.Equal(t, false, routesAfterEnabling1[0].GetIsPrimary())
+
+	// Node 2 is primary
+	assert.Equal(t, true, routesAfterEnabling1[1].GetAdvertised())
+	assert.Equal(t, true, routesAfterEnabling1[1].GetEnabled())
+	assert.Equal(t, true, routesAfterEnabling1[1].GetIsPrimary())
+
+	// Verify that the route is still announced from subnet router 2
+	clientStatus, err = client.Status()
+	assertNoErr(t, err)
+
+	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+
+	assert.Nil(t, srs1PeerStatus.PrimaryRoutes)
+	assert.NotNil(t, srs2PeerStatus.PrimaryRoutes)
+
+	if srs2PeerStatus.PrimaryRoutes != nil {
+		assert.Contains(
+			t,
+			srs2PeerStatus.PrimaryRoutes.AsSlice(),
+			netip.MustParsePrefix(expectedRoutes[string(srs2.Self.ID)]),
+		)
+	}
+
+	// Delete the route of subnet router 2; failover to subnet router 1 is expected
+	t.Logf("deleting route in subnet router 2 (%s)", subRouter2.Hostname())
+	_, err = headscale.Execute(
+		[]string{
+			"headscale",
+			"routes",
+			"delete",
+			"--route",
+			fmt.Sprintf("%d", routesAfterEnabling1[1].GetId()),
+		})
+	assertNoErr(t, err)
+
+	time.Sleep(5 * time.Second)
+
+	var routesAfterDeleting2 []*v1.Route
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"routes",
+			"list",
+			"--output",
+			"json",
+		},
+		&routesAfterDeleting2,
+	)
+	assertNoErr(t, err)
+	assert.Len(t, routesAfterDeleting2, 1)
+
+	t.Logf("routes after deleting2 %#v", routesAfterDeleting2)
+
+	// Node 1 is primary
+	assert.Equal(t, true, routesAfterDeleting2[0].GetAdvertised())
+	assert.Equal(t, true, routesAfterDeleting2[0].GetEnabled())
+	assert.Equal(t, true, routesAfterDeleting2[0].GetIsPrimary())
+
+	// Verify that the route is announced from subnet router 1
+	clientStatus, err = client.Status()
+	assertNoErr(t, err)
+
+	srs1PeerStatus = clientStatus.Peer[srs1.Self.PublicKey]
+	srs2PeerStatus = clientStatus.Peer[srs2.Self.PublicKey]
+
+	assertNotNil(t, srs1PeerStatus.PrimaryRoutes)
+	assert.Nil(t, srs2PeerStatus.PrimaryRoutes)
+
+	if srs1PeerStatus.PrimaryRoutes != nil {
+		assert.Contains(
+			t,
+			srs1PeerStatus.PrimaryRoutes.AsSlice(),
+			netip.MustParsePrefix(expectedRoutes[string(srs1.Self.ID)]),
+		)
+	}
+}
diff --git a/integration/run.sh b/integration/run.sh
index b03338a5..8c1fb016 100755
--- a/integration/run.sh
+++ b/integration/run.sh
@@ -29,7 +29,7 @@ run_tests() {
 		-failfast \
 		-timeout 120m \
 		-parallel 1 \
-		-run "^$test_name\$" >/dev/null 2>&1
+		-run "^$test_name\$" >./control_logs/"$test_name"_"$i".log 2>&1
 
 	status=$?
 	end=$(date +%s)
diff --git a/integration/scenario.go b/integration/scenario.go
index 5e0ccf14..6bcd5852 100644
--- a/integration/scenario.go
+++ b/integration/scenario.go
@@ -15,6 +15,7 @@ import (
 	"github.com/juanfont/headscale/integration/tsic"
 	"github.com/ory/dockertest/v3"
 	"github.com/puzpuzpuz/xsync/v3"
+	"github.com/samber/lo"
 	"golang.org/x/sync/errgroup"
 )
 
@@ -93,7 +94,7 @@ var (
 	//
 	// - Two unstable (HEAD and unstable)
 	// - Two latest versions
-	// - Two oldest versions.
+	// - Two oldest supported versions.
MustTestVersions = append( AllVersions[0:4], AllVersions[len(AllVersions)-2:]..., @@ -296,11 +297,13 @@ func (s *Scenario) CreateTailscaleNodesInUser( opts ...tsic.Option, ) error { if user, ok := s.users[userStr]; ok { + var versions []string for i := 0; i < count; i++ { version := requestedVersion if requestedVersion == "all" { version = MustTestVersions[i%len(MustTestVersions)] } + versions = append(versions, version) headscale, err := s.Headscale() if err != nil { @@ -350,6 +353,8 @@ func (s *Scenario) CreateTailscaleNodesInUser( return err } + log.Printf("testing versions %v", lo.Uniq(versions)) + return nil } @@ -403,7 +408,17 @@ func (s *Scenario) CountTailscale() int { func (s *Scenario) WaitForTailscaleSync() error { tsCount := s.CountTailscale() - return s.WaitForTailscaleSyncWithPeerCount(tsCount - 1) + err := s.WaitForTailscaleSyncWithPeerCount(tsCount - 1) + if err != nil { + for _, user := range s.users { + for _, client := range user.Clients { + peers, _ := client.PrettyPeers() + log.Println(peers) + } + } + } + + return err } // WaitForTailscaleSyncWithPeerCount blocks execution until all the TailscaleClient reports diff --git a/integration/ssh_test.go b/integration/ssh_test.go index 88e62e9d..587190e4 100644 --- a/integration/ssh_test.go +++ b/integration/ssh_test.go @@ -109,7 +109,7 @@ func TestSSHOneUserToAll(t *testing.T) { }, }, }, - len(MustTestVersions)-2, + len(MustTestVersions), ) defer scenario.Shutdown() @@ -174,7 +174,7 @@ func TestSSHMultipleUsersAllToAll(t *testing.T) { }, }, }, - len(MustTestVersions)-2, + len(MustTestVersions), ) defer scenario.Shutdown() @@ -220,7 +220,7 @@ func TestSSHNoSSHConfigured(t *testing.T) { }, SSHs: []policy.SSH{}, }, - len(MustTestVersions)-2, + len(MustTestVersions), ) defer scenario.Shutdown() @@ -269,7 +269,7 @@ func TestSSHIsBlockedInACL(t *testing.T) { }, }, }, - len(MustTestVersions)-2, + len(MustTestVersions), ) defer scenario.Shutdown() @@ -325,7 +325,7 @@ func TestSSHUserOnlyIsolation(t *testing.T) { }, }, }, - len(MustTestVersions)-2, + len(MustTestVersions), ) defer scenario.Shutdown() diff --git a/integration/tailscale.go b/integration/tailscale.go index ba63e7d6..e7bf71b9 100644 --- a/integration/tailscale.go +++ b/integration/tailscale.go @@ -21,6 +21,8 @@ type TailscaleClient interface { Login(loginServer, authKey string) error LoginWithURL(loginServer string) (*url.URL, error) Logout() error + Up() error + Down() error IPs() ([]netip.Addr, error) FQDN() (string, error) Status() (*ipnstate.Status, error) @@ -30,4 +32,5 @@ type TailscaleClient interface { Ping(hostnameOrIP string, opts ...tsic.PingOption) error Curl(url string, opts ...tsic.CurlOption) (string, error) ID() string + PrettyPeers() (string, error) } diff --git a/integration/tsic/tsic.go b/integration/tsic/tsic.go index efe9c904..7404f6ea 100644 --- a/integration/tsic/tsic.go +++ b/integration/tsic/tsic.go @@ -285,6 +285,15 @@ func (t *TailscaleInContainer) hasTLS() bool { // Shutdown stops and cleans up the Tailscale container. func (t *TailscaleInContainer) Shutdown() error { + err := t.SaveLog("/tmp/control") + if err != nil { + log.Printf( + "Failed to save log from %s: %s", + t.hostname, + fmt.Errorf("failed to save log: %w", err), + ) + } + return t.pool.Purge(t.container) } @@ -417,6 +426,44 @@ func (t *TailscaleInContainer) Logout() error { return nil } +// Helper that runs `tailscale up` with no arguments. 
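+// In the integration tests it is used to simulate a subnet router
+// coming back online after Down().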
+func (t *TailscaleInContainer) Up() error {
+	command := []string{
+		"tailscale",
+		"up",
+	}
+
+	if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil {
+		return fmt.Errorf(
+			"%s failed to bring tailscale client up (%s): %w",
+			t.hostname,
+			strings.Join(command, " "),
+			err,
+		)
+	}
+
+	return nil
+}
+
+// Helper that runs `tailscale down` with no arguments.
+func (t *TailscaleInContainer) Down() error {
+	command := []string{
+		"tailscale",
+		"down",
+	}
+
+	if _, _, err := t.Execute(command, dockertestutil.ExecuteCommandTimeout(dockerExecuteTimeout)); err != nil {
+		return fmt.Errorf(
+			"%s failed to bring tailscale client down (%s): %w",
+			t.hostname,
+			strings.Join(command, " "),
+			err,
+		)
+	}
+
+	return nil
+}
+
 // IPs returns the netip.Addr of the Tailscale instance.
 func (t *TailscaleInContainer) IPs() ([]netip.Addr, error) {
 	if t.ips != nil && len(t.ips) != 0 {
@@ -486,6 +533,34 @@ func (t *TailscaleInContainer) FQDN() (string, error) {
 	return status.Self.DNSName, nil
 }
 
+// PrettyPeers returns a plain-text table of the peers known to the client.
+func (t *TailscaleInContainer) PrettyPeers() (string, error) {
+	status, err := t.Status()
+	if err != nil {
+		return "", fmt.Errorf("failed to get status: %w", err)
+	}
+
+	str := fmt.Sprintf("Peers of %s\n", t.hostname)
+	str += "Hostname\tOnline\tLastSeen\n"
+
+	peerCount := len(status.Peers())
+	onlineCount := 0
+
+	for _, peerKey := range status.Peers() {
+		peer := status.Peer[peerKey]
+
+		if peer.Online {
+			onlineCount++
+		}
+
+		str += fmt.Sprintf("%s\t%t\t%s\n", peer.HostName, peer.Online, peer.LastSeen)
+	}
+
+	str += fmt.Sprintf("Peer Count: %d, Online Count: %d\n\n", peerCount, onlineCount)
+
+	return str, nil
+}
+
 // WaitForNeedsLogin blocks until the Tailscale (tailscaled) instance has
 // started and needs to be logged into.
 func (t *TailscaleInContainer) WaitForNeedsLogin() error {
@@ -531,7 +606,7 @@ func (t *TailscaleInContainer) WaitForRunning() error {
 }
 
 // WaitForPeers blocks until N number of peers is present in the
-// Peer list of the Tailscale instance.
+// Peer list of the Tailscale instance and all of them are reporting Online.
 func (t *TailscaleInContainer) WaitForPeers(expected int) error {
 	return t.pool.Retry(func() error {
 		status, err := t.Status()
@@ -547,6 +622,14 @@ func (t *TailscaleInContainer) WaitForPeers(expected int) error {
 				expected,
 				len(peers),
 			)
+		} else {
+			for _, peerKey := range peers {
+				peer := status.Peer[peerKey]
+
+				if !peer.Online {
+					return fmt.Errorf("[%s] peer count correct, but %s is not online", t.hostname, peer.HostName)
+				}
+			}
 		}
 
 		return nil
@@ -738,3 +821,9 @@ func (t *TailscaleInContainer) Curl(url string, opts ...CurlOption) (string, err
 func (t *TailscaleInContainer) WriteFile(path string, data []byte) error {
 	return integrationutil.WriteFileToContainer(t.pool, t.container, path, data)
 }
+
+// SaveLog saves the current stdout log of the container to a path
+// on the host system.
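+// It is called from Shutdown so that the client logs survive the
+// container being purged.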
+func (t *TailscaleInContainer) SaveLog(path string) error { + return dockertestutil.SaveLog(t.pool, t.container, path) +} diff --git a/integration/utils.go b/integration/utils.go index 91e274b1..e17e18a2 100644 --- a/integration/utils.go +++ b/integration/utils.go @@ -26,6 +26,13 @@ func assertNoErrf(t *testing.T, msg string, err error) { } } +func assertNotNil(t *testing.T, thing interface{}) { + t.Helper() + if thing == nil { + t.Fatal("got unexpected nil") + } +} + func assertNoErrHeadscaleEnv(t *testing.T, err error) { t.Helper() assertNoErrf(t, "failed to create headscale environment: %s", err) @@ -68,13 +75,13 @@ func assertContains(t *testing.T, str, subStr string) { } } -func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string) int { +func pingAllHelper(t *testing.T, clients []TailscaleClient, addrs []string, opts ...tsic.PingOption) int { t.Helper() success := 0 for _, client := range clients { for _, addr := range addrs { - err := client.Ping(addr) + err := client.Ping(addr, opts...) if err != nil { t.Fatalf("failed to ping %s from %s: %s", addr, client.Hostname(), err) } else {