Compare commits

...

7 Commits

Author SHA1 Message Date
Julian cbbf71b759
Merge c6593a2bcb into 622aa82da2 2024-05-03 08:44:28 +05:30
Kristoffer Dalby 622aa82da2
ensure expire routines are cleaned up (#1924)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-05-02 15:57:53 +00:00
Kristoffer Dalby a9c568c801
trace log and notifier shutdown (#1922)
Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-05-02 13:39:19 +02:00
Kristoffer Dalby 1c6bfc503c
fix preauth key logging in as previous user (#1920)
* add test case to reproduce #1885

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* fix preauth key issue logging in as wrong user

Fixes #1885

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* add test to gh

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
2024-05-02 11:53:16 +02:00
Julian Mundhahs c6593a2bcb feat: add a message to the user that the session is recorded
Signed-off-by: Julian Mundhahs <julian.mundhahs@mailbox.org>
2024-03-10 14:43:31 +01:00
Julian Mundhahs 31ee9bb4a8 test: add tests for `recorder` and `enforceRecorder` SSH ACL rules
Signed-off-by: Julian Mundhahs <github@mundhahs.dev>
2024-03-10 14:43:26 +01:00
Julian Mundhahs 04c7c5b0b6 feat: add `recorder` and `enforceRecorder` support in SSH ACL
Signed-off-by: Julian Mundhahs <github@mundhahs.dev>
2024-03-10 14:24:17 +01:00
8 changed files with 338 additions and 59 deletions

View File

@ -26,6 +26,7 @@ jobs:
- TestPreAuthKeyCommand
- TestPreAuthKeyCommandWithoutExpiry
- TestPreAuthKeyCommandReusableEphemeral
- TestPreAuthKeyCorrectUserLoggedInCommand
- TestApiKeyCommand
- TestNodeTagCommand
- TestNodeAdvertiseTagNoACLCommand

View File

@ -70,7 +70,7 @@ var (
const (
AuthPrefix = "Bearer "
updateInterval = 5000
updateInterval = 5 * time.Second
privateKeyFileMode = 0o600
headscaleDirPerm = 0o700
@ -219,64 +219,75 @@ func (h *Headscale) redirect(w http.ResponseWriter, req *http.Request) {
// deleteExpireEphemeralNodes deletes ephemeral node records that have not been
// seen for longer than h.cfg.EphemeralNodeInactivityTimeout.
func (h *Headscale) deleteExpireEphemeralNodes(milliSeconds int64) {
ticker := time.NewTicker(time.Duration(milliSeconds) * time.Millisecond)
func (h *Headscale) deleteExpireEphemeralNodes(ctx context.Context, every time.Duration) {
ticker := time.NewTicker(every)
for range ticker.C {
var removed []types.NodeID
var changed []types.NodeID
if err := h.db.Write(func(tx *gorm.DB) error {
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
var removed []types.NodeID
var changed []types.NodeID
if err := h.db.Write(func(tx *gorm.DB) error {
removed, changed = db.DeleteExpiredEphemeralNodes(tx, h.cfg.EphemeralNodeInactivityTimeout)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
continue
}
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring ephemeral nodes")
continue
}
if removed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: removed,
})
}
if removed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerRemoved,
Removed: removed,
})
}
if changed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changed,
})
if changed != nil {
ctx := types.NotifyCtx(context.Background(), "expire-ephemeral", "na")
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{
Type: types.StatePeerChanged,
ChangeNodes: changed,
})
}
}
}
}
// expireExpiredMachines expires nodes that have an explicit expiry set
// expireExpiredNodes expires nodes that have an explicit expiry set
// after that expiry time has passed.
func (h *Headscale) expireExpiredMachines(intervalMs int64) {
interval := time.Duration(intervalMs) * time.Millisecond
ticker := time.NewTicker(interval)
func (h *Headscale) expireExpiredNodes(ctx context.Context, every time.Duration) {
ticker := time.NewTicker(every)
lastCheck := time.Unix(0, 0)
var update types.StateUpdate
var changed bool
for range ticker.C {
if err := h.db.Write(func(tx *gorm.DB) error {
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
for {
select {
case <-ctx.Done():
ticker.Stop()
return
case <-ticker.C:
if err := h.db.Write(func(tx *gorm.DB) error {
lastCheck, update, changed = db.ExpireExpiredNodes(tx, lastCheck)
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring nodes")
continue
}
return nil
}); err != nil {
log.Error().Err(err).Msg("database error while expiring nodes")
continue
}
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
if changed {
log.Trace().Interface("nodes", update.ChangePatches).Msgf("expiring nodes")
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
ctx := types.NotifyCtx(context.Background(), "expire-expired", "na")
h.nodeNotifier.NotifyAll(ctx, update)
}
}
}
}
@ -538,10 +549,13 @@ func (h *Headscale) Serve() error {
return errEmptyInitialDERPMap
}
// TODO(kradalby): These should have cancel channels and be cleaned
// up on shutdown.
go h.deleteExpireEphemeralNodes(updateInterval)
go h.expireExpiredMachines(updateInterval)
expireEphemeralCtx, expireEphemeralCancel := context.WithCancel(context.Background())
defer expireEphemeralCancel()
go h.deleteExpireEphemeralNodes(expireEphemeralCtx, updateInterval)
expireNodeCtx, expireNodeCancel := context.WithCancel(context.Background())
defer expireNodeCancel()
go h.expireExpiredNodes(expireNodeCtx, updateInterval)
if zl.GlobalLevel() == zl.TraceLevel {
zerolog.RespLog = true
@ -800,10 +814,26 @@ func (h *Headscale) Serve() error {
}
default:
trace := log.Trace().Msgf
log.Info().
Str("signal", sig.String()).
Msg("Received signal to stop, shutting down gracefully")
expireNodeCancel()
expireEphemeralCancel()
trace("closing map sessions")
wg := sync.WaitGroup{}
for _, mapSess := range h.mapSessions {
wg.Add(1)
go func() {
mapSess.close()
wg.Done()
}()
}
wg.Wait()
trace("waiting for netmap stream to close")
h.pollNetMapStreamWG.Wait()
// Gracefully shut down servers
@ -811,32 +841,44 @@ func (h *Headscale) Serve() error {
context.Background(),
types.HTTPShutdownTimeout,
)
trace("shutting down debug http server")
if err := debugHTTPServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown prometheus http")
}
trace("shutting down main http server")
if err := httpServer.Shutdown(ctx); err != nil {
log.Error().Err(err).Msg("Failed to shutdown http")
}
trace("shutting down grpc server (socket)")
grpcSocket.GracefulStop()
if grpcServer != nil {
trace("shutting down grpc server (external)")
grpcServer.GracefulStop()
grpcListener.Close()
}
if tailsqlContext != nil {
trace("shutting down tailsql")
tailsqlContext.Done()
}
trace("closing node notifier")
h.nodeNotifier.Close()
// Close network listeners
trace("closing network listeners")
debugHTTPListener.Close()
httpListener.Close()
grpcGatewayConn.Close()
// Stop listening (and unlink the socket if unix type):
trace("closing socket listener")
socketListener.Close()
// Close db connections
trace("closing database connection")
err = h.db.Close()
if err != nil {
log.Error().Err(err).Msg("Failed to close db")

View File

@ -315,13 +315,16 @@ func (h *Headscale) handleAuthKey(
node.NodeKey = nodeKey
node.AuthKeyID = uint(pak.ID)
err := h.db.NodeSetExpiry(node.ID, registerRequest.Expiry)
node.Expiry = &registerRequest.Expiry
node.User = pak.User
node.UserID = pak.UserID
err := h.db.DB.Save(node).Error
if err != nil {
log.Error().
Caller().
Str("node", node.Hostname).
Err(err).
Msg("Failed to refresh node")
Msg("failed to save node after logging in with auth key")
return
}
@ -344,7 +347,7 @@ func (h *Headscale) handleAuthKey(
}
ctx := types.NotifyCtx(context.Background(), "handle-authkey", "na")
h.nodeNotifier.NotifyWithIgnore(ctx, types.StateUpdateExpire(node.ID, registerRequest.Expiry), node.ID)
h.nodeNotifier.NotifyAll(ctx, types.StateUpdate{Type: types.StatePeerChanged, ChangeNodes: []types.NodeID{node.ID}})
} else {
now := time.Now().UTC()

View File

@ -34,6 +34,11 @@ func NewNotifier(cfg *types.Config) *Notifier {
return n
}
// Close stops the batcher inside the notifier.
func (n *Notifier) Close() {
n.b.close()
}
func (n *Notifier) AddNode(nodeID types.NodeID, c chan<- types.StateUpdate) {
log.Trace().Caller().Uint64("node.id", nodeID.Uint64()).Msg("acquiring lock to add node")
defer log.Trace().

View File

@ -357,6 +357,31 @@ func (pol *ACLPolicy) CompileSSHPolicy(
return nil, fmt.Errorf("parsing SSH policy, unknown action %q, index: %d: %w", sshACL.Action, index, err)
}
recs := make([]netip.AddrPort, 0)
for innerIndex, rec := range sshACL.Recorder {
recSet, err := pol.ExpandAlias(append(peers, node), rec)
if err != nil {
log.Error().
Msgf("Error parsing SSH %d, Recorder %d", index, innerIndex)
return nil, err
}
// ExpandAlias has expanded possible subnets; all prefixes are single IPs
for _, rec := range recSet.Prefixes() {
recs = append(recs, netip.AddrPortFrom(rec.Addr(), 80))
}
}
if len(recs) > 0 {
action.Message = "# This session is being recorded.\n"
}
action.Recorders = recs
if sshACL.EnforceRecorder {
action.OnRecordingFailure = &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "# Failed to start session recording.",
TerminateSessionWithMessage: "# Failed to start session recording.",
}
}
principals := make([]*tailcfg.SSHPrincipal, 0, len(sshACL.Sources))
for innerIndex, rawSrc := range sshACL.Sources {
if isWildcard(rawSrc) {

View File

@ -6,6 +6,7 @@ import (
"testing"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"github.com/rs/zerolog/log"
@ -3340,7 +3341,11 @@ func TestSSHRules(t *testing.T) {
SSHUsers: map[string]string{
"autogroup:nonroot": "=",
},
Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
Action: &tailcfg.SSHAction{
Accept: true,
AllowLocalPortForwarding: true,
Recorders: []netip.AddrPort{},
},
},
{
SSHUsers: map[string]string{
@ -3351,7 +3356,11 @@ func TestSSHRules(t *testing.T) {
Any: true,
},
},
Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
Action: &tailcfg.SSHAction{
Accept: true,
AllowLocalPortForwarding: true,
Recorders: []netip.AddrPort{},
},
},
{
Principals: []*tailcfg.SSHPrincipal{
@ -3362,7 +3371,11 @@ func TestSSHRules(t *testing.T) {
SSHUsers: map[string]string{
"autogroup:nonroot": "=",
},
Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
Action: &tailcfg.SSHAction{
Accept: true,
AllowLocalPortForwarding: true,
Recorders: []netip.AddrPort{},
},
},
{
SSHUsers: map[string]string{
@ -3373,7 +3386,100 @@ func TestSSHRules(t *testing.T) {
Any: true,
},
},
Action: &tailcfg.SSHAction{Accept: true, AllowLocalPortForwarding: true},
Action: &tailcfg.SSHAction{
Accept: true,
AllowLocalPortForwarding: true,
Recorders: []netip.AddrPort{},
},
},
},
},
{
name: "ssh-session-recording",
node: types.Node{
Hostname: "testnodes",
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.99.42")},
UserID: 0,
User: types.User{
Name: "user1",
},
},
peers: types.Nodes{
&types.Node{
Hostname: "testnodes2",
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.1")},
UserID: 0,
User: types.User{
Name: "user1",
},
},
},
pol: ACLPolicy{
Groups: Groups{
"group:test": []string{"user1"},
},
Hosts: Hosts{
"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32),
},
ACLs: []ACL{
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"*:*"},
},
},
SSHs: []SSH{
{
Action: "accept",
Sources: []string{"group:test"},
Destinations: []string{"client"},
Users: []string{"autogroup:nonroot"},
Recorder: []string{"group:test"},
},
{
Action: "accept",
Sources: []string{"*"},
Destinations: []string{"client"},
Users: []string{"autogroup:nonroot"},
Recorder: []string{"client"},
EnforceRecorder: true,
},
},
},
want: []*tailcfg.SSHRule{
{
Principals: []*tailcfg.SSHPrincipal{
{
UserLogin: "user1",
},
},
SSHUsers: map[string]string{
"autogroup:nonroot": "=",
},
Action: &tailcfg.SSHAction{
Accept: true,
AllowLocalPortForwarding: true,
Recorders: []netip.AddrPort{netip.MustParseAddrPort("100.64.0.1:80"), netip.MustParseAddrPort("100.64.99.42:80")},
},
},
{
SSHUsers: map[string]string{
"autogroup:nonroot": "=",
},
Principals: []*tailcfg.SSHPrincipal{
{
Any: true,
},
},
Action: &tailcfg.SSHAction{
Accept: true,
AllowLocalPortForwarding: true,
Recorders: []netip.AddrPort{netip.MustParseAddrPort("100.64.99.42:80")},
OnRecordingFailure: &tailcfg.SSHRecorderFailureAction{
RejectSessionWithMessage: "# Failed to start session recording.",
TerminateSessionWithMessage: "# Failed to start session recording.",
},
},
},
}},
},
@ -3435,7 +3541,7 @@ func TestSSHRules(t *testing.T) {
got, err := tt.pol.CompileSSHPolicy(&tt.node, tt.peers)
assert.NoError(t, err)
if diff := cmp.Diff(tt.want, got); diff != "" {
if diff := cmp.Diff(tt.want, got, cmpopts.EquateComparable(netip.AddrPort{})); diff != "" {
t.Errorf("TestSSHRules() unexpected result (-want +got):\n%s", diff)
}
})

View File

@ -53,11 +53,13 @@ type AutoApprovers struct {
// SSH controls who can ssh into which machines.
type SSH struct {
Action string `json:"action" yaml:"action"`
Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"`
Users []string `json:"users" yaml:"users"`
CheckPeriod string `json:"checkPeriod,omitempty" yaml:"checkPeriod,omitempty"`
Action string `json:"action" yaml:"action"`
Sources []string `json:"src" yaml:"src"`
Destinations []string `json:"dst" yaml:"dst"`
Users []string `json:"users" yaml:"users"`
CheckPeriod string `json:"checkPeriod,omitempty" yaml:"checkPeriod,omitempty"`
Recorder []string `json:"recorder,omitempty" yaml:"recorder,omitempty"`
EnforceRecorder bool `json:"enforceRecorder,omitempty" yaml:"enforceRecorder,omitempty"`
}
// UnmarshalJSON allows to parse the Hosts directly into netip objects.

View File

@ -388,6 +388,101 @@ func TestPreAuthKeyCommandReusableEphemeral(t *testing.T) {
assert.Len(t, listedPreAuthKeys, 3)
}
func TestPreAuthKeyCorrectUserLoggedInCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()
user1 := "user1"
user2 := "user2"
scenario, err := NewScenario(dockertestMaxWait())
assertNoErr(t, err)
defer scenario.Shutdown()
spec := map[string]int{
user1: 1,
user2: 0,
}
err = scenario.CreateHeadscaleEnv(spec, []tsic.Option{}, hsic.WithTestName("clipak"))
assertNoErr(t, err)
headscale, err := scenario.Headscale()
assertNoErr(t, err)
var user2Key v1.PreAuthKey
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"preauthkeys",
"--user",
user2,
"create",
"--reusable",
"--expiration",
"24h",
"--output",
"json",
"--tags",
"tag:test1,tag:test2",
},
&user2Key,
)
assertNoErr(t, err)
allClients, err := scenario.ListTailscaleClients()
assertNoErrListClients(t, err)
assert.Len(t, allClients, 1)
client := allClients[0]
// Log out from user1
err = client.Logout()
assertNoErr(t, err)
err = scenario.WaitForTailscaleLogout()
assertNoErr(t, err)
status, err := client.Status()
assertNoErr(t, err)
if status.BackendState == "Starting" || status.BackendState == "Running" {
t.Fatalf("expected node to be logged out, backend state: %s", status.BackendState)
}
err = client.Login(headscale.GetEndpoint(), user2Key.GetKey())
assertNoErr(t, err)
status, err = client.Status()
assertNoErr(t, err)
if status.BackendState != "Running" {
t.Fatalf("expected node to be logged in, backend state: %s", status.BackendState)
}
if status.Self.UserID.String() != "userid:2" {
t.Fatalf("expected node to be logged in as userid:2, got: %s", status.Self.UserID.String())
}
var listNodes []v1.Node
err = executeAndUnmarshal(
headscale,
[]string{
"headscale",
"nodes",
"list",
"--output",
"json",
},
&listNodes,
)
assert.Nil(t, err)
assert.Len(t, listNodes, 1)
assert.Equal(t, "user2", listNodes[0].User.Name)
}
func TestApiKeyCommand(t *testing.T) {
IntegrationSkip(t)
t.Parallel()