Compare commits
6 Commits
fa4400b736
...
ac49825cbf
Author | SHA1 | Date |
---|---|---|
Rorical | ac49825cbf | |
Kristoffer Dalby | 622aa82da2 | |
Kristoffer Dalby | a9c568c801 | |
Kristoffer Dalby | 1c6bfc503c | |
Kristoffer Dalby | 55b35f4160 | |
Rorical | 38c148745a |
|
@ -26,6 +26,7 @@ jobs:
|
|||
- TestPreAuthKeyCommand
|
||||
- TestPreAuthKeyCommandWithoutExpiry
|
||||
- TestPreAuthKeyCommandReusableEphemeral
|
||||
- TestPreAuthKeyCorrectUserLoggedInCommand
|
||||
- TestApiKeyCommand
|
||||
- TestNodeTagCommand
|
||||
- TestNodeAdvertiseTagNoACLCommand
|
||||
|
|
134
hscontrol/app.go
134
hscontrol/app.go
|
@ -70,7 +70,7 @@ var (
|
|||
|
||||
const (
|
||||
AuthPrefix = "Bearer "
|
||||
updateInterval = 5000
|
||||
updateInterval = 5 * time.Second
|
||||
privateKeyFileMode = 0o600
|
||||
headscaleDirPerm = 0o700
|
||||
|
||||
|
@ -219,64 +219,75 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
|
|||
|
||||
// deleteExpireEphemeralNodes deletes ephemeral node records that have not been
|
||||
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
|
||||
func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
|
||||
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
|
||||
func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) {
|
||||
ticker := time.NewTicker(every)
|
||||
|
||||
for range ticker.C {
|
||||
var removed []types.NodeID
|
||||
var changed []types.NodeID
|
||||
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
var removed []types.NodeID
|
||||
var changed []types.NodeID
|
||||
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
|
||||
continue
|
||||
}
|
||||
|
||||
if removed != nil {
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: removed,
|
||||
})
|
||||
}
|
||||
if removed != nil {
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StatePeerRemoved,
|
||||
Removed: removed,
|
||||
})
|
||||
}
|
||||
|
||||
if changed != nil {
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: changed,
|
||||
})
|
||||
if changed != nil {
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
|
||||
Type: types.StatePeerChanged,
|
||||
ChangeNodes: changed,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// expireExpiredMachines expires nodes that have an explicit expiry set
|
||||
// expireExpiredNodes expires nodes that have an explicit expiry set
|
||||
// after that expiry time has passed.
|
||||
func (h *Headscale) expireExpiredMachines(intervalMs int64) {
|
||||
interval := time.Duration(intervalMs) * time.Millisecond
|
||||
ticker := time.NewTicker(interval)
|
||||
func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) {
|
||||
ticker := time.NewTicker(every)
|
||||
|
||||
lastCheck := time.Unix(0, 0)
|
||||
var update types.StateUpdate
|
||||
var changed bool
|
||||
|
||||
for range ticker.C {
|
||||
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
ticker.Stop()
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := h.db.Write(func(tx *gorm.DB) error {
|
||||
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("database error while expiring nodes")
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("database error while expiring nodes")
|
||||
continue
|
||||
}
|
||||
|
||||
if changed {
|
||||
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
|
||||
if changed {
|
||||
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, update)
|
||||
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
|
||||
h.nodeNotifier.NotifyAll(ctx, update)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -538,10 +549,13 @@ func (h *Headscale) Serve() error {
|
|||
return errEmptyInitialDERPMap
|
||||
}
|
||||
|
||||
// TODO(kradalby): These should have cancel channels and be cleaned
|
||||
// up on shutdown.
|
||||
go h.deleteExpireEphemeralNodes(updateInterval)
|
||||
go h.expireExpiredMachines(updateInterval)
|
||||
expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background())
|
||||
defer expireEphemeralCancel()
|
||||
go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval)
|
||||
|
||||
expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background())
|
||||
defer expireNodeCancel()
|
||||
go h.expireExpiredNodes(expireNodeCtx, updateInterval)
|
||||
|
||||
if zl.GlobalLevel() == zl.TraceLevel {
|
||||
zerolog.RespLog = true
|
||||
|
@ -800,10 +814,26 @@ func (h *Headscale) Serve() error {
|
|||
}
|
||||
|
||||
default:
|
||||
trace := log.Trace().Msgf
|
||||
log.Info().
|
||||
Str("signal", sig.String()).
|
||||
Msg("Received signal to stop, shutting down gracefully")
|
||||
|
||||
expireNodeCancel()
|
||||
expireEphemeralCancel()
|
||||
|
||||
trace("closing map sessions")
|
||||
wg := sync.WaitGroup{}
|
||||
for _, mapSess := range h.mapSessions {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
mapSess.close()
|
||||
wg.Done()
|
||||
}()
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
trace("waiting for netmap stream to close")
|
||||
h.pollNetMapStreamWG.Wait()
|
||||
|
||||
// Gracefully shut down servers
|
||||
|
@ -811,32 +841,44 @@ func (h *Headscale) Serve() error {
|
|||
context.Background(),
|
||||
types.HTTPShutdownTimeout,
|
||||
)
|
||||
trace("shutting down debug http server")
|
||||
if err := debugHTTPServer.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
|
||||
}
|
||||
trace("shutting down main http server")
|
||||
if err := httpServer.Shutdown(ctx); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to shutdown http")
|
||||
}
|
||||
|
||||
trace("shutting down grpc server (socket)")
|
||||
grpcSocket.GracefulStop()
|
||||
|
||||
if grpcServer != nil {
|
||||
trace("shutting down grpc server (external)")
|
||||
grpcServer.GracefulStop()
|
||||
grpcListener.Close()
|
||||
}
|
||||
|
||||
if tailsqlContext != nil {
|
||||
trace("shutting down tailsql")
|
||||
tailsqlContext.Done()
|
||||
}
|
||||
|
||||
trace("closing node notifier")
|
||||
h.nodeNotifier.Close()
|
||||
|
||||
// Close network listeners
|
||||
trace("closing network listeners")
|
||||
debugHTTPListener.Close()
|
||||
httpListener.Close()
|
||||
grpcGatewayConn.Close()
|
||||
|
||||
// Stop listening (and unlink the socket if unix type):
|
||||
trace("closing socket listener")
|
||||
socketListener.Close()
|
||||
|
||||
// Close db connections
|
||||
trace("closing database connection")
|
||||
err = h.db.Close()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to close db")
|
||||
|
|
|
@ -315,13 +315,16 @@ func (h *Headscale) handleAuthKey(
|
|||
|
||||
node.NodeKey = nodeKey
|
||||
node.AuthKeyID = uint(pak.ID)
|
||||
err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
|
||||
node.Expiry = ®isterRequest.Expiry
|
||||
node.User = pak.User
|
||||
node.UserID = pak.UserID
|
||||
err := h.db.DB.Save(node).Error
|
||||
if err != nil {
|
||||
log.Error().
|
||||
Caller().
|
||||
Str("node", node.Hostname).
|
||||
Err(err).
|
||||
Msg("Failed to refresh node")
|
||||
Msg("failed to save node after logging in with auth key")
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -344,7 +347,7 @@ func (h *Headscale) handleAuthKey(
|
|||
}
|
||||
|
||||
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
|
||||
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, registerRequest.Expiry), node.ID)
|
||||
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{node.ID}})
|
||||
} else {
|
||||
now := time.Now().UTC()
|
||||
|
||||
|
|
|
@ -34,6 +34,11 @@ func NewNotifier(cfg *types.Config) *Notifier {
|
|||
return n
|
||||
}
|
||||
|
||||
// Close stops the batcher inside the notifier.
|
||||
func (n *Notifier) Close() {
|
||||
n.b.close()
|
||||
}
|
||||
|
||||
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
|
||||
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
|
||||
defer log.Trace().
|
||||
|
|
|
@ -133,19 +133,25 @@ func (h *Headscale) RegisterOIDC(
|
|||
|
||||
stateStr := hex.EncodeToString(randomBlob)[:32]
|
||||
|
||||
// place the node key into the state cache, so it can be retrieved later
|
||||
// generate PKCE code verifier
|
||||
verifier := oauth2.GenerateVerifier()
|
||||
// place the node key and verifier into the state cache, so it can be retrieved later
|
||||
h.registrationCache.Set(
|
||||
stateStr,
|
||||
machineKey,
|
||||
types.RegistrationInfo{
|
||||
MachineKey: machineKey,
|
||||
Verifier: verifier,
|
||||
},
|
||||
registerCacheExpiration,
|
||||
)
|
||||
|
||||
// Add any extra parameter provided in the configuration to the Authorize Endpoint request
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
|
||||
extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)+2)
|
||||
|
||||
for k, v := range h.cfg.OIDC.ExtraParams {
|
||||
extras = append(extras, oauth2.SetAuthURLParam(k, v))
|
||||
}
|
||||
extras = append(extras, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
|
||||
|
||||
authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
|
||||
log.Debug().Msgf("Redirecting to %s for authentication", authURL)
|
||||
|
@ -179,7 +185,33 @@ func (h *Headscale) OIDCCallback(
|
|||
return
|
||||
}
|
||||
|
||||
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
|
||||
regState, stateFound := h.registrationCache.Get(state)
|
||||
if !stateFound {
|
||||
log.Trace().
|
||||
Msg("requested state key expired before authorisation completed")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state has expired"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
return
|
||||
}
|
||||
regInfo, regInfoOK := regState.(types.RegistrationInfo)
|
||||
if !regInfoOK {
|
||||
log.Trace().
|
||||
Interface("got", regInfo).
|
||||
Msg("requested state is not a RegistrationInfo")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state is invalid"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state, regInfo)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
@ -216,7 +248,7 @@ func (h *Headscale) OIDCCallback(
|
|||
|
||||
machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
|
||||
writer,
|
||||
state,
|
||||
regInfo,
|
||||
claims,
|
||||
idTokenExpiry,
|
||||
)
|
||||
|
@ -278,8 +310,9 @@ func (h *Headscale) getIDTokenForOIDCCallback(
|
|||
ctx context.Context,
|
||||
writer http.ResponseWriter,
|
||||
code, state string,
|
||||
regInfo types.RegistrationInfo,
|
||||
) (string, error) {
|
||||
oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
|
||||
oauth2Token, err := h.oauth2Config.Exchange(ctx, code, oauth2.VerifierOption(regInfo.Verifier))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Could not exchange code for token")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
|
@ -441,46 +474,17 @@ func validateOIDCAllowedUsers(
|
|||
// on to registration.
|
||||
func (h *Headscale) validateNodeForOIDCCallback(
|
||||
writer http.ResponseWriter,
|
||||
state string,
|
||||
state types.RegistrationInfo,
|
||||
claims *IDTokenClaims,
|
||||
expiry time.Time,
|
||||
) (*key.MachinePublic, bool, error) {
|
||||
// retrieve nodekey from state cache
|
||||
machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
|
||||
if !machineKeyFound {
|
||||
log.Trace().
|
||||
Msg("requested node state key expired before authorisation completed")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state has expired"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCNodeKeyMissing
|
||||
}
|
||||
|
||||
var machineKey key.MachinePublic
|
||||
machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
|
||||
if !machineKeyOK {
|
||||
log.Trace().
|
||||
Interface("got", machineKeyIf).
|
||||
Msg("requested node state key is not a nodekey")
|
||||
writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
writer.WriteHeader(http.StatusBadRequest)
|
||||
_, err := writer.Write([]byte("state is invalid"))
|
||||
if err != nil {
|
||||
util.LogErr(err, "Failed to write response")
|
||||
}
|
||||
|
||||
return nil, false, errOIDCInvalidNodeState
|
||||
}
|
||||
|
||||
// retrieve node information if it exist
|
||||
// The error is not important, because if it does not
|
||||
// exist, then this is a new node and we will move
|
||||
// on to registration.
|
||||
node, _ := h.db.GetNodeByMachineKey(machineKey)
|
||||
node, _ := h.db.GetNodeByMachineKey(state.MachineKey)
|
||||
|
||||
if node != nil {
|
||||
log.Trace().
|
||||
|
@ -532,7 +536,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
|
|||
return nil, true, nil
|
||||
}
|
||||
|
||||
return &machineKey, false, nil
|
||||
return &state.MachineKey, false, nil
|
||||
}
|
||||
|
||||
func getUserName(
|
||||
|
|
|
@ -306,11 +306,15 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
|
|||
}
|
||||
node.NodeKey = nodeKey
|
||||
|
||||
var discoKey key.DiscoPublic
|
||||
if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("unmarshalling disco key from db: %w", err)
|
||||
// DiscoKey might be empty if a node has not sent it to headscale.
|
||||
// This means that this might fail if the disco key is empty.
|
||||
if node.DiscoKeyDatabaseField != "" {
|
||||
var discoKey key.DiscoPublic
|
||||
if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
|
||||
return fmt.Errorf("unmarshalling disco key from db: %w", err)
|
||||
}
|
||||
node.DiscoKey = discoKey
|
||||
}
|
||||
node.DiscoKey = discoKey
|
||||
|
||||
endpoints := make([]netip.AddrPort, len(node.EndpointsDatabaseField))
|
||||
for idx, ep := range node.EndpointsDatabaseField {
|
||||
|
@ -530,3 +534,9 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
|
|||
|
||||
return ret
|
||||
}
|
||||
|
||||
// RegistrationInfo contains both machine key and verifier information for OIDC validation.
|
||||
type RegistrationInfo struct {
|
||||
MachineKey key.MachinePublic
|
||||
Verifier string
|
||||
}
|
||||
|
|
|
@ -388,6 +388,101 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
|
|||
assert.Len(t, listedPreAuthKeys, 3)
|
||||
}
|
||||
|
||||
func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
||||
user1 := "user1"
|
||||
user2 := "user2"
|
||||
|
||||
scenario, err := NewScenario(dockertestMaxWait())
|
||||
assertNoErr(t, err)
|
||||
defer scenario.Shutdown()
|
||||
|
||||
spec := map[string]int{
|
||||
user1: 1,
|
||||
user2: 0,
|
||||
}
|
||||
|
||||
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
|
||||
assertNoErr(t, err)
|
||||
|
||||
headscale, err := scenario.Headscale()
|
||||
assertNoErr(t, err)
|
||||
|
||||
var user2Key v1.PreAuthKey
|
||||
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"preauthkeys",
|
||||
"--user",
|
||||
user2,
|
||||
"create",
|
||||
"--reusable",
|
||||
"--expiration",
|
||||
"24h",
|
||||
"--output",
|
||||
"json",
|
||||
"--tags",
|
||||
"tag:test1,tag:test2",
|
||||
},
|
||||
&user2Key,
|
||||
)
|
||||
assertNoErr(t, err)
|
||||
|
||||
allClients, err := scenario.ListTailscaleClients()
|
||||
assertNoErrListClients(t, err)
|
||||
|
||||
assert.Len(t, allClients, 1)
|
||||
|
||||
client := allClients[0]
|
||||
|
||||
// Log out from user1
|
||||
err = client.Logout()
|
||||
assertNoErr(t, err)
|
||||
|
||||
err = scenario.WaitForTailscaleLogout()
|
||||
assertNoErr(t, err)
|
||||
|
||||
status, err := client.Status()
|
||||
assertNoErr(t, err)
|
||||
if status.BackendState == "Starting" || status.BackendState == "Running" {
|
||||
t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState)
|
||||
}
|
||||
|
||||
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
|
||||
assertNoErr(t, err)
|
||||
|
||||
status, err = client.Status()
|
||||
assertNoErr(t, err)
|
||||
if status.BackendState != "Running" {
|
||||
t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState)
|
||||
}
|
||||
|
||||
if status.Self.UserID.String() != "userid:2" {
|
||||
t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String())
|
||||
}
|
||||
|
||||
var listNodes []v1.Node
|
||||
err = executeAndUnmarshal(
|
||||
headscale,
|
||||
[]string{
|
||||
"headscale",
|
||||
"nodes",
|
||||
"list",
|
||||
"--output",
|
||||
"json",
|
||||
},
|
||||
&listNodes,
|
||||
)
|
||||
assert.Nil(t, err)
|
||||
assert.Len(t, listNodes, 1)
|
||||
|
||||
assert.Equal(t, "user2", listNodes[0].User.Name)
|
||||
}
|
||||
|
||||
func TestApiKeyCommand(t *testing.T) {
|
||||
IntegrationSkip(t)
|
||||
t.Parallel()
|
||||
|
|
Loading…
Reference in New Issue