Compare commits


6 Commits

Author SHA1 Message Date
Rorical ac49825cbf
Merge 38c148745a into 622aa82da2
2024-05-03 08:44:27 +05:30
Kristoffer Dalby 622aa82da2
ensure expire routines are cleaned up (#1924)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-05-02 15:57:53 +00:00
Kristoffer Dalby a9c568c801
trace log and notifier shutdown (#1922)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-05-02 13:39:19 +02:00
Kristoffer Dalby 1c6bfc503c
fix preauth key logging in as previous user (#1920)
* add test case to reproduce #1885

* fix preauth key issue logging in as wrong user

Fixes #1885

* add test to gh

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-05-02 11:53:16 +02:00
Kristoffer Dalby 55b35f4160
fix issue preventing get node when disco is missing (#1919)
Fixed #1816

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-05-01 17:06:42 +02:00
Rorical 38c148745a
add PKCE verifier for OIDC
2024-03-05 09:13:23 +01:00
7 changed files with 251 additions and 91 deletions

View File

@@ -26,6 +26,7 @@ jobs:
           - TestPreAuthKeyCommand
           - TestPreAuthKeyCommandWithoutExpiry
           - TestPreAuthKeyCommandReusableEphemeral
+          - TestPreAuthKeyCorrectUserLoggedInCommand
           - TestApiKeyCommand
           - TestNodeTagCommand
           - TestNodeAdvertiseTagNoACLCommand

View File

@@ -70,7 +70,7 @@ var (
 const (
 	AuthPrefix         = "Bearer "
-	updateInterval     = 5000
+	updateInterval     = 5 * time.Second
 	privateKeyFileMode = 0o600
 	headscaleDirPerm   = 0o700
@@ -219,64 +219,75 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
 // deleteExpireEphemeralNodes deletes ephemeral node records that have not been
 // seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
-func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
-	ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
-	for range ticker.C {
-		var removed []types.NodeID
-		var changed []types.NodeID
-		if err := h.db.Write(func(tx *gorm.DB) error {
-			removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
-			return nil
-		}); err != nil {
-			log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
-			continue
-		}
-		if removed != nil {
-			ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
-			h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
-				Type:    types.StatePeerRemoved,
-				Removed: removed,
-			})
-		}
-		if changed != nil {
-			ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
-			h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
-				Type:        types.StatePeerChanged,
-				ChangeNodes: changed,
-			})
-		}
-	}
-}
+func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) {
+	ticker := time.NewTicker(every)
+	for {
+		select {
+		case <-ctx.Done():
+			ticker.Stop()
+			return
+		case <-ticker.C:
+			var removed []types.NodeID
+			var changed []types.NodeID
+			if err := h.db.Write(func(tx *gorm.DB) error {
+				removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
+				return nil
+			}); err != nil {
+				log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
+				continue
+			}
+			if removed != nil {
+				ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
+				h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
+					Type:    types.StatePeerRemoved,
+					Removed: removed,
+				})
+			}
+			if changed != nil {
+				ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
+				h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
+					Type:        types.StatePeerChanged,
+					ChangeNodes: changed,
+				})
+			}
+		}
+	}
+}
 
-// expireExpiredMachines expires nodes that have an explicit expiry set
+// expireExpiredNodes expires nodes that have an explicit expiry set
 // after that expiry time has passed.
-func (h *Headscale) expireExpiredMachines(intervalMs int64) {
-	interval := time.Duration(intervalMs) * time.Millisecond
-	ticker := time.NewTicker(interval)
+func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) {
+	ticker := time.NewTicker(every)
 	lastCheck := time.Unix(0, 0)
 	var update types.StateUpdate
 	var changed bool
-	for range ticker.C {
-		if err := h.db.Write(func(tx *gorm.DB) error {
-			lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
-			return nil
-		}); err != nil {
-			log.Error().Err(err).Msg("database error while expiring nodes")
-			continue
-		}
-		if changed {
-			log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
-			ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
-			h.nodeNotifier.NotifyAll(ctx, update)
-		}
-	}
-}
+	for {
+		select {
+		case <-ctx.Done():
+			ticker.Stop()
+			return
+		case <-ticker.C:
+			if err := h.db.Write(func(tx *gorm.DB) error {
+				lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
+				return nil
+			}); err != nil {
+				log.Error().Err(err).Msg("database error while expiring nodes")
+				continue
+			}
+			if changed {
+				log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
+				ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
+				h.nodeNotifier.NotifyAll(ctx, update)
+			}
+		}
+	}
+}
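
Note: the heart of #1924 is a standard Go pattern: replace a bare for range ticker.C loop with a select over the ticker channel and a cancellable context, so the caller can stop the goroutine on shutdown. A minimal, self-contained sketch of the pattern (names here are illustrative, not headscale's):

package main

import (
	"context"
	"fmt"
	"time"
)

// runEvery calls work on every tick until ctx is cancelled.
func runEvery(ctx context.Context, every time.Duration, work func()) {
	ticker := time.NewTicker(every)
	for {
		select {
		case <-ctx.Done():
			ticker.Stop() // release ticker resources on shutdown
			return
		case <-ticker.C:
			work()
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	go runEvery(ctx, 100*time.Millisecond, func() { fmt.Println("tick") })

	time.Sleep(350 * time.Millisecond)
	cancel() // mirrors expireEphemeralCancel()/expireNodeCancel() in Serve
	time.Sleep(50 * time.Millisecond) // give the goroutine time to exit
}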
@@ -538,10 +549,13 @@ func (h *Headscale) Serve() error {
 		return errEmptyInitialDERPMap
 	}
 
-	// TODO(kradalby): These should have cancel channels and be cleaned
-	// up on shutdown.
-	go h.deleteExpireEphemeralNodes(updateInterval)
-	go h.expireExpiredMachines(updateInterval)
+	expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background())
+	defer expireEphemeralCancel()
+	go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval)
+
+	expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background())
+	defer expireNodeCancel()
+	go h.expireExpiredNodes(expireNodeCtx, updateInterval)
 
 	if zl.GlobalLevel() == zl.TraceLevel {
 		zerolog.RespLog = true
@@ -800,10 +814,26 @@ func (h *Headscale) Serve() error {
 			}
 		default:
+			trace := log.Trace().Msgf
+
 			log.Info().
 				Str("signal", sig.String()).
 				Msg("Received signal to stop, shutting down gracefully")
 
+			expireNodeCancel()
+			expireEphemeralCancel()
+
+			trace("closing map sessions")
+			wg := sync.WaitGroup{}
+			for _, mapSess := range h.mapSessions {
+				wg.Add(1)
+				go func() {
+					mapSess.close()
+					wg.Done()
+				}()
+			}
+			wg.Wait()
+
+			trace("waiting for netmap stream to close")
 			h.pollNetMapStreamWG.Wait()
 
 			// Gracefully shut down servers
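
Note: the shutdown path now closes every map session concurrently and blocks until all of them have returned. A self-contained sketch of this fan-out-and-wait pattern (the session type and close method are stand-ins, not headscale's):

package main

import (
	"fmt"
	"sync"
)

type session struct{ id int }

func (s *session) close() { fmt.Println("closed session", s.id) }

func main() {
	sessions := []*session{{1}, {2}, {3}}

	var wg sync.WaitGroup
	for _, s := range sessions {
		s := s // capture the loop variable; needed before Go 1.22
		wg.Add(1)
		go func() {
			defer wg.Done()
			s.close()
		}()
	}
	wg.Wait() // block until every session has closed
}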
@@ -811,32 +841,44 @@ func (h *Headscale) Serve() error {
 				context.Background(),
 				types.HTTPShutdownTimeout,
 			)
+			trace("shutting down debug http server")
 			if err := debugHTTPServer.Shutdown(ctx); err != nil {
 				log.Error().Err(err).Msg("Failed to shutdown prometheus http")
 			}
+			trace("shutting down main http server")
 			if err := httpServer.Shutdown(ctx); err != nil {
 				log.Error().Err(err).Msg("Failed to shutdown http")
 			}
+			trace("shutting down grpc server (socket)")
 			grpcSocket.GracefulStop()
 
 			if grpcServer != nil {
+				trace("shutting down grpc server (external)")
 				grpcServer.GracefulStop()
 				grpcListener.Close()
 			}
 
 			if tailsqlContext != nil {
+				trace("shutting down tailsql")
 				tailsqlContext.Done()
 			}
 
+			trace("closing node notifier")
+			h.nodeNotifier.Close()
+
 			// Close network listeners
+			trace("closing network listeners")
 			debugHTTPListener.Close()
 			httpListener.Close()
 			grpcGatewayConn.Close()
 
 			// Stop listening (and unlink the socket if unix type):
+			trace("closing socket listener")
 			socketListener.Close()
 
 			// Close db connections
+			trace("closing database connection")
 			err = h.db.Close()
 			if err != nil {
 				log.Error().Err(err).Msg("Failed to close db")
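
Note: each shutdown step now emits a trace line first, so a hang can be attributed to a specific step when running at trace level. For reference, a minimal sketch of graceful http.Server shutdown under a deadline, the same shape used above (the timeout value is illustrative; headscale uses types.HTTPShutdownTimeout):

package main

import (
	"context"
	"log"
	"net/http"
	"time"
)

func main() {
	srv := &http.Server{Addr: ":8080"}
	go srv.ListenAndServe() // returns http.ErrServerClosed after Shutdown

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	// Shutdown stops accepting new connections and waits for in-flight
	// requests to finish, or gives up when ctx expires.
	if err := srv.Shutdown(ctx); err != nil {
		log.Printf("forced shutdown: %v", err)
	}
}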

View File

@@ -315,13 +315,16 @@ func (h *Headscale) handleAuthKey(
 		node.NodeKey = nodeKey
 		node.AuthKeyID = uint(pak.ID)
-		err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
+		node.Expiry = &registerRequest.Expiry
+		node.User = pak.User
+		node.UserID = pak.UserID
+		err := h.db.DB.Save(node).Error
 		if err != nil {
 			log.Error().
 				Caller().
 				Str("node", node.Hostname).
 				Err(err).
-				Msg("Failed to refresh node")
+				Msg("failed to save node after logging in with auth key")
 
 			return
 		}
@@ -344,7 +347,7 @@ func (h *Headscale) handleAuthKey(
 		}
 
 		ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
-		h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, registerRequest.Expiry), node.ID)
+		h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{node.ID}})
 	} else {
 		now := time.Now().UTC()
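
Note: this is the actual fix for #1885. Previously only the expiry was refreshed, so a node re-authenticating with another user's pre-auth key kept its old owner; now User, UserID and Expiry are reassigned and the whole record is saved. A reduced, self-contained GORM sketch of the idea (models are illustrative stand-ins for headscale's):

package main

import (
	"fmt"
	"time"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"
)

// Illustrative stand-ins for headscale's User and Node models.
type User struct {
	ID   uint
	Name string
}

type Node struct {
	ID     uint
	Expiry *time.Time
	UserID uint
	User   User
}

func main() {
	db, err := gorm.Open(sqlite.Open(":memory:"), &gorm.Config{})
	if err != nil {
		panic(err)
	}
	if err := db.AutoMigrate(&User{}, &Node{}); err != nil {
		panic(err)
	}

	user1, user2 := User{Name: "user1"}, User{Name: "user2"}
	db.Create(&user1)
	db.Create(&user2)
	node := Node{UserID: user1.ID, User: user1}
	db.Create(&node)

	// The fix: reassign ownership and expiry, then persist the whole
	// record in one write instead of only bumping the expiry.
	expiry := time.Now().Add(24 * time.Hour)
	node.Expiry = &expiry
	node.User = user2
	node.UserID = user2.ID
	if err := db.Save(&node).Error; err != nil {
		panic(fmt.Errorf("saving node: %w", err))
	}
	fmt.Println("node now belongs to user id", node.UserID)
}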

View File

@@ -34,6 +34,11 @@ func NewNotifier(cfg *types.Config) *Notifier {
 	return n
 }
 
+// Close stops the batcher inside the notifier.
+func (n *Notifier) Close() {
+	n.b.close()
+}
+
 func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
 	log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
 	defer log.Trace().
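
Note: Close exists so Serve's shutdown path can stop the notifier's internal batcher goroutine. The batcher itself is not shown in this diff; a minimal sketch of one plausible shape, assuming a done-channel implementation:

package main

import (
	"fmt"
	"time"
)

type batcher struct{ done chan struct{} }

func newBatcher() *batcher {
	b := &batcher{done: make(chan struct{})}
	go b.run()
	return b
}

func (b *batcher) run() {
	ticker := time.NewTicker(100 * time.Millisecond)
	defer ticker.Stop()
	for {
		select {
		case <-b.done:
			return // stop flushing once closed
		case <-ticker.C:
			fmt.Println("flush pending updates")
		}
	}
}

// close mirrors what Notifier.Close delegates to via n.b.close().
func (b *batcher) close() { close(b.done) }

func main() {
	b := newBatcher()
	time.Sleep(250 * time.Millisecond)
	b.close()
	time.Sleep(50 * time.Millisecond)
}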

View File

@@ -133,19 +133,25 @@ func (h *Headscale) RegisterOIDC(
 	stateStr := hex.EncodeToString(randomBlob)[:32]
 
-	// place the node key into the state cache, so it can be retrieved later
+	// generate PKCE code verifier
+	verifier := oauth2.GenerateVerifier()
+
+	// place the node key and verifier into the state cache, so it can be retrieved later
 	h.registrationCache.Set(
 		stateStr,
-		machineKey,
+		types.RegistrationInfo{
+			MachineKey: machineKey,
+			Verifier:   verifier,
+		},
 		registerCacheExpiration,
 	)
 
 	// Add any extra parameter provided in the configuration to the Authorize Endpoint request
-	extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams))
+	extras := make([]oauth2.AuthCodeOption, 0, len(h.cfg.OIDC.ExtraParams)+2)
 
 	for k, v := range h.cfg.OIDC.ExtraParams {
 		extras = append(extras, oauth2.SetAuthURLParam(k, v))
 	}
+	extras = append(extras, oauth2.AccessTypeOffline, oauth2.S256ChallengeOption(verifier))
 
 	authURL := h.oauth2Config.AuthCodeURL(stateStr, extras...)
 	log.Debug().Msgf("Redirecting to %s for authentication", authURL)
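
Note: PKCE in brief: generate a random verifier, send only its S256 challenge with the authorization request, then prove possession by sending the raw verifier at code exchange. The three golang.org/x/oauth2 helpers used in this diff (available in recent versions of the package) fit together like this; endpoints and credentials are placeholders:

package main

import (
	"context"
	"fmt"

	"golang.org/x/oauth2"
)

func main() {
	conf := &oauth2.Config{
		ClientID:     "client-id",     // placeholder
		ClientSecret: "client-secret", // placeholder
		RedirectURL:  "https://hs.example.com/oidc/callback",
		Endpoint: oauth2.Endpoint{
			AuthURL:  "https://idp.example.com/auth",
			TokenURL: "https://idp.example.com/token",
		},
	}

	// 1. Generate the verifier and derive its S256 challenge for the
	//    authorization URL. Only the challenge leaves the server here.
	verifier := oauth2.GenerateVerifier()
	authURL := conf.AuthCodeURL("state",
		oauth2.AccessTypeOffline,
		oauth2.S256ChallengeOption(verifier))
	fmt.Println("redirect user to:", authURL)

	// 2. At the callback, prove possession of the verifier during the
	//    code exchange; the IdP checks it against the earlier challenge.
	code := "code-from-callback" // placeholder
	tok, err := conf.Exchange(context.Background(), code,
		oauth2.VerifierOption(verifier))
	if err != nil {
		fmt.Println("exchange failed (expected with placeholder endpoint):", err)
		return
	}
	fmt.Println("access token:", tok.AccessToken)
}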
@@ -179,7 +185,33 @@ func (h *Headscale) OIDCCallback(
 		return
 	}
 
-	rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state)
+	regState, stateFound := h.registrationCache.Get(state)
+	if !stateFound {
+		log.Trace().
+			Msg("requested state key expired before authorisation completed")
+		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
+		writer.WriteHeader(http.StatusBadRequest)
+		_, err := writer.Write([]byte("state has expired"))
+		if err != nil {
+			util.LogErr(err, "Failed to write response")
+		}
+
+		return
+	}
+
+	regInfo, regInfoOK := regState.(types.RegistrationInfo)
+	if !regInfoOK {
+		log.Trace().
+			Interface("got", regInfo).
+			Msg("requested state is not a RegistrationInfo")
+		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
+		writer.WriteHeader(http.StatusBadRequest)
+		_, err := writer.Write([]byte("state is invalid"))
+		if err != nil {
+			util.LogErr(err, "Failed to write response")
+		}
+
+		return
+	}
+
+	rawIDToken, err := h.getIDTokenForOIDCCallback(req.Context(), writer, code, state, regInfo)
 	if err != nil {
 		return
 	}
@@ -216,7 +248,7 @@ func (h *Headscale) OIDCCallback(
 	machineKey, nodeExists, err := h.validateNodeForOIDCCallback(
 		writer,
-		state,
+		regInfo,
 		claims,
 		idTokenExpiry,
 	)
@@ -278,8 +310,9 @@ func (h *Headscale) getIDTokenForOIDCCallback(
 	ctx context.Context,
 	writer http.ResponseWriter,
 	code, state string,
+	regInfo types.RegistrationInfo,
 ) (string, error) {
-	oauth2Token, err := h.oauth2Config.Exchange(ctx, code)
+	oauth2Token, err := h.oauth2Config.Exchange(ctx, code, oauth2.VerifierOption(regInfo.Verifier))
 	if err != nil {
 		util.LogErr(err, "Could not exchange code for token")
 		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
@@ -441,46 +474,17 @@ func validateOIDCAllowedUsers(
 // on to registration.
 func (h *Headscale) validateNodeForOIDCCallback(
 	writer http.ResponseWriter,
-	state string,
+	state types.RegistrationInfo,
 	claims *IDTokenClaims,
 	expiry time.Time,
 ) (*key.MachinePublic, bool, error) {
-	// retrieve nodekey from state cache
-	machineKeyIf, machineKeyFound := h.registrationCache.Get(state)
-	if !machineKeyFound {
-		log.Trace().
-			Msg("requested node state key expired before authorisation completed")
-		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
-		writer.WriteHeader(http.StatusBadRequest)
-		_, err := writer.Write([]byte("state has expired"))
-		if err != nil {
-			util.LogErr(err, "Failed to write response")
-		}
-
-		return nil, false, errOIDCNodeKeyMissing
-	}
-
-	var machineKey key.MachinePublic
-	machineKey, machineKeyOK := machineKeyIf.(key.MachinePublic)
-	if !machineKeyOK {
-		log.Trace().
-			Interface("got", machineKeyIf).
-			Msg("requested node state key is not a nodekey")
-		writer.Header().Set("Content-Type", "text/plain; charset=utf-8")
-		writer.WriteHeader(http.StatusBadRequest)
-		_, err := writer.Write([]byte("state is invalid"))
-		if err != nil {
-			util.LogErr(err, "Failed to write response")
-		}
-
-		return nil, false, errOIDCInvalidNodeState
-	}
-
 	// retrieve node information if it exist
 	// The error is not important, because if it does not
 	// exist, then this is a new node and we will move
 	// on to registration.
-	node, _ := h.db.GetNodeByMachineKey(machineKey)
+	node, _ := h.db.GetNodeByMachineKey(state.MachineKey)
 
 	if node != nil {
 		log.Trace().
@@ -532,7 +536,7 @@ func (h *Headscale) validateNodeForOIDCCallback(
 		return nil, true, nil
 	}
 
-	return &machineKey, false, nil
+	return &state.MachineKey, false, nil
 }
 
 func getUserName(
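
Note: the net effect of this hunk is that the cache lookup and type assertion move out of validateNodeForOIDCCallback and into OIDCCallback, so the state is validated once, up front, and everything downstream receives a typed types.RegistrationInfo. A reduced sketch of that shape (plain map instead of headscale's registration cache, string instead of key.MachinePublic):

package main

import (
	"fmt"
)

// Illustrative stand-in for headscale's types.RegistrationInfo.
type RegistrationInfo struct {
	MachineKey string // headscale uses key.MachinePublic here
	Verifier   string
}

// lookupRegistration asserts the cached value once, up front, so every
// later step can take a typed RegistrationInfo instead of re-fetching.
func lookupRegistration(cache map[string]any, state string) (RegistrationInfo, error) {
	v, ok := cache[state]
	if !ok {
		return RegistrationInfo{}, fmt.Errorf("state %q expired", state)
	}
	info, ok := v.(RegistrationInfo)
	if !ok {
		return RegistrationInfo{}, fmt.Errorf("state %q is not a RegistrationInfo", state)
	}
	return info, nil
}

func main() {
	cache := map[string]any{"abc": RegistrationInfo{MachineKey: "mkey", Verifier: "v"}}
	info, err := lookupRegistration(cache, "abc")
	fmt.Println(info, err)
}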

View File

@@ -306,11 +306,15 @@ func (node *Node) AfterFind(tx *gorm.DB) error {
 	}
 	node.NodeKey = nodeKey
 
-	var discoKey key.DiscoPublic
-	if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
-		return fmt.Errorf("unmarshalling disco key from db: %w", err)
+	// DiscoKey might be empty if a node has not yet sent it to headscale;
+	// unmarshalling an empty value would fail, so skip it in that case.
+	if node.DiscoKeyDatabaseField != "" {
+		var discoKey key.DiscoPublic
+		if err := discoKey.UnmarshalText([]byte(node.DiscoKeyDatabaseField)); err != nil {
+			return fmt.Errorf("unmarshalling disco key from db: %w", err)
+		}
+		node.DiscoKey = discoKey
 	}
-	node.DiscoKey = discoKey
 
 	endpoints := make([]netip.AddrPort, len(node.EndpointsDatabaseField))
 	for idx, ep := range node.EndpointsDatabaseField {
@@ -530,3 +534,9 @@ func (nodes Nodes) IDMap() map[NodeID]*Node {
 	return ret
 }
+
+// RegistrationInfo contains both machine key and verifier information for OIDC validation.
+type RegistrationInfo struct {
+	MachineKey key.MachinePublic
+	Verifier   string
+}
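
Note: the #1919 fix in brief: nodes that have never sent a disco key store an empty string, and UnmarshalText rejects empty input, so AfterFind must skip parsing in that case instead of failing the whole row load. A reduced sketch of the guard with a stand-in parser (not tailscale.com/types/key):

package main

import (
	"errors"
	"fmt"
)

// parseDiscoKey stands in for key.DiscoPublic.UnmarshalText, which
// rejects empty input.
func parseDiscoKey(s string) (string, error) {
	if s == "" {
		return "", errors.New("empty disco key")
	}
	return "parsed:" + s, nil
}

func loadDiscoKey(fromDB string) (string, error) {
	// Guard: a node that never sent a disco key stores "" in the
	// database; skip parsing instead of failing the whole row load.
	if fromDB == "" {
		return "", nil
	}
	return parseDiscoKey(fromDB)
}

func main() {
	for _, raw := range []string{"", "discokey:abc123"} {
		k, err := loadDiscoKey(raw)
		fmt.Printf("raw=%q parsed=%q err=%v\n", raw, k, err)
	}
}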

View File

@@ -388,6 +388,101 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
 	assert.Len(t, listedPreAuthKeys, 3)
 }
 
+func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
+	IntegrationSkip(t)
+	t.Parallel()
+
+	user1 := "user1"
+	user2 := "user2"
+
+	scenario, err := NewScenario(dockertestMaxWait())
+	assertNoErr(t, err)
+	defer scenario.Shutdown()
+
+	spec := map[string]int{
+		user1: 1,
+		user2: 0,
+	}
+
+	err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
+	assertNoErr(t, err)
+
+	headscale, err := scenario.Headscale()
+	assertNoErr(t, err)
+
+	var user2Key v1.PreAuthKey
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"preauthkeys",
+			"--user",
+			user2,
+			"create",
+			"--reusable",
+			"--expiration",
+			"24h",
+			"--output",
+			"json",
+			"--tags",
+			"tag:test1,tag:test2",
+		},
+		&user2Key,
+	)
+	assertNoErr(t, err)
+
+	allClients, err := scenario.ListTailscaleClients()
+	assertNoErrListClients(t, err)
+
+	assert.Len(t, allClients, 1)
+
+	client := allClients[0]
+
+	// Log out from user1
+	err = client.Logout()
+	assertNoErr(t, err)
+
+	err = scenario.WaitForTailscaleLogout()
+	assertNoErr(t, err)
+
+	status, err := client.Status()
+	assertNoErr(t, err)
+	if status.BackendState == "Starting" || status.BackendState == "Running" {
+		t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState)
+	}
+
+	err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
+	assertNoErr(t, err)
+
+	status, err = client.Status()
+	assertNoErr(t, err)
+	if status.BackendState != "Running" {
+		t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState)
+	}
+
+	if status.Self.UserID.String() != "userid:2" {
+		t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String())
+	}
+
+	var listNodes []v1.Node
+	err = executeAndUnmarshal(
+		headscale,
+		[]string{
+			"headscale",
+			"nodes",
+			"list",
+			"--output",
+			"json",
+		},
+		&listNodes,
+	)
+	assert.Nil(t, err)
+	assert.Len(t, listNodes, 1)
+
+	assert.Equal(t, "user2", listNodes[0].User.Name)
+}
+
 func TestApiKeyCommand(t *testing.T) {
 	IntegrationSkip(t)
 	t.Parallel()