move to use tailscfg types over strings/custom types (#1612)

* rename database only fields

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* use correct endpoint type over string list

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* remove HostInfo wrapper

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

* wrap errors in database hooks

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>

---------

Signed-off-by: Kristoffer Dalby <kristoffer@tailscale.com>
This commit is contained in:
Kristoffer Dalby 2023-11-21 18:20:06 +01:00 committed by GitHub
parent ed4e19996b
commit b918aa03fc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 147 additions and 154 deletions

View File

@ -12,6 +12,7 @@ import (
"github.com/juanfont/headscale/hscontrol/types"
"github.com/juanfont/headscale/hscontrol/util"
"gopkg.in/check.v1"
"tailscale.com/tailcfg"
"tailscale.com/types/key"
)
@ -593,7 +594,7 @@ func (s *Suite) TestAutoApproveRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:exit"},
RoutableIPs: []netip.Prefix{defaultRouteV4, defaultRouteV6, route1, route2},
},

View File

@ -274,7 +274,7 @@ func (hsdb *HSDatabase) saveNodeRoutes(node *types.Node) error {
}
advertisedRoutes := map[netip.Prefix]bool{}
for _, prefix := range node.HostInfo.RoutableIPs {
for _, prefix := range node.Hostinfo.RoutableIPs {
advertisedRoutes[prefix] = false
}

View File

@ -33,7 +33,7 @@ func (s *Suite) TestGetRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
@ -81,7 +81,7 @@ func (s *Suite) TestGetEnableRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
db.db.Save(&node)
@ -152,7 +152,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo1),
Hostinfo: &hostInfo1,
}
db.db.Save(&node1)
@ -174,7 +174,7 @@ func (s *Suite) TestIsUniquePrefix(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo2),
Hostinfo: &hostInfo2,
}
db.db.Save(&node2)
@ -232,7 +232,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo1),
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)
@ -266,7 +266,7 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo2),
Hostinfo: &hostInfo2,
LastSeen: &now,
}
db.db.Save(&node2)
@ -313,9 +313,9 @@ func (s *Suite) TestSubnetFailover(c *check.C) {
c.Assert(err, check.IsNil)
c.Assert(len(routes), check.Equals, 1)
node2.HostInfo = types.HostInfo(tailcfg.Hostinfo{
node2.Hostinfo = &tailcfg.Hostinfo{
RoutableIPs: []netip.Prefix{prefix, prefix2},
})
}
err = db.db.Save(&node2).Error
c.Assert(err, check.IsNil)
@ -368,7 +368,7 @@ func (s *Suite) TestDeleteRoutes(c *check.C) {
UserID: user.ID,
RegisterMethod: util.RegisterMethodAuthKey,
AuthKeyID: uint(pak.ID),
HostInfo: types.HostInfo(hostInfo1),
Hostinfo: &hostInfo1,
LastSeen: &now,
}
db.db.Save(&node1)

View File

@ -550,7 +550,7 @@ func (api headscaleV1APIServer) DebugCreateNode(
Expiry: &time.Time{},
LastSeen: &time.Time{},
HostInfo: types.HostInfo(hostinfo),
Hostinfo: &hostinfo,
}
log.Debug().

View File

@ -195,7 +195,7 @@ func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {
attrs := url.Values{
"device_name": []string{node.Hostname},
"device_model": []string{node.HostInfo.OS},
"device_model": []string{node.Hostinfo.OS},
}
if len(node.IPAddresses) > 0 {

View File

@ -186,8 +186,7 @@ func Test_fullMapResponse(t *testing.T) {
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),
@ -267,8 +266,7 @@ func Test_fullMapResponse(t *testing.T) {
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{},
CreatedAt: created,
}
@ -324,8 +322,7 @@ func Test_fullMapResponse(t *testing.T) {
ForcedTags: []string{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{},
CreatedAt: created,
}

View File

@ -72,8 +72,8 @@ func tailNode(
}
var derp string
if node.HostInfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", node.HostInfo.NetInfo.PreferredDERP)
if node.Hostinfo.NetInfo != nil {
derp = fmt.Sprintf("127.3.3.40:%d", node.Hostinfo.NetInfo.PreferredDERP)
} else {
derp = "127.3.3.40:0" // Zero means disconnected or unknown.
}
@ -90,18 +90,11 @@ func tailNode(
return nil, err
}
hostInfo := node.GetHostInfo()
online := node.IsOnline()
tags, _ := pol.TagsOfNode(node)
tags = lo.Uniq(append(tags, node.ForcedTags...))
endpoints, err := node.EndpointsToAddrPort()
if err != nil {
return nil, err
}
tNode := tailcfg.Node{
ID: tailcfg.NodeID(node.ID), // this is the actual ID
StableID: tailcfg.StableNodeID(
@ -118,9 +111,9 @@ func tailNode(
DiscoKey: node.DiscoKey,
Addresses: addrs,
AllowedIPs: allowedIPs,
Endpoints: endpoints,
Endpoints: node.Endpoints,
DERP: derp,
Hostinfo: hostInfo.View(),
Hostinfo: node.Hostinfo.View(),
Created: node.CreatedAt,
Tags: tags,

View File

@ -53,8 +53,10 @@ func TestTailNode(t *testing.T) {
wantErr bool
}{
{
name: "empty-node",
node: &types.Node{},
name: "empty-node",
node: &types.Node{
Hostinfo: &tailcfg.Hostinfo{},
},
pol: &policy.ACLPolicy{},
dnsConfig: &tailcfg.DNSConfig{},
baseDomain: "",
@ -102,8 +104,7 @@ func TestTailNode(t *testing.T) {
AuthKey: &types.PreAuthKey{},
LastSeen: &lastSeen,
Expiry: &expire,
HostInfo: types.HostInfo{},
Endpoints: []string{},
Hostinfo: &tailcfg.Hostinfo{},
Routes: []types.Route{
{
Prefix: types.IPPrefix(netip.MustParsePrefix("0.0.0.0/0")),

View File

@ -596,10 +596,13 @@ func excludeCorrectlyTaggedNodes(
}
// for each node if tag is in tags list, don't append it.
for _, node := range nodes {
hi := node.GetHostInfo()
found := false
for _, t := range hi.RequestTags {
if node.Hostinfo == nil {
continue
}
for _, t := range node.Hostinfo.RequestTags {
if util.StringOrPrefixListContains(tags, t) {
found = true
@ -787,8 +790,11 @@ func (pol *ACLPolicy) expandIPsFromTag(
for _, user := range owners {
nodes := filterNodesByUser(nodes, user)
for _, node := range nodes {
hi := node.GetHostInfo()
if util.StringOrPrefixListContains(hi.RequestTags, alias) {
if node.Hostinfo == nil {
continue
}
if util.StringOrPrefixListContains(node.Hostinfo.RequestTags, alias) {
node.IPAddresses.AppendToIPSet(&build)
}
}
@ -882,7 +888,7 @@ func (pol *ACLPolicy) TagsOfNode(
validTagMap := make(map[string]bool)
invalidTagMap := make(map[string]bool)
for _, tag := range node.HostInfo.RequestTags {
for _, tag := range node.Hostinfo.RequestTags {
owners, err := expandOwnersFromTag(pol, tag)
if errors.Is(err, ErrInvalidTag) {
invalidTagMap[tag] = true

View File

@ -418,6 +418,7 @@ acls:
User: types.User{
Name: "testuser",
},
Hostinfo: &tailcfg.Hostinfo{},
},
})
@ -1264,7 +1265,7 @@ func Test_expandAlias(t *testing.T) {
netip.MustParseAddr("100.64.0.1"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:hr-webserver"},
@ -1275,7 +1276,7 @@ func Test_expandAlias(t *testing.T) {
netip.MustParseAddr("100.64.0.2"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:hr-webserver"},
@ -1405,7 +1406,7 @@ func Test_expandAlias(t *testing.T) {
netip.MustParseAddr("100.64.0.2"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:hr-webserver"},
@ -1443,7 +1444,7 @@ func Test_expandAlias(t *testing.T) {
netip.MustParseAddr("100.64.0.1"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
@ -1454,7 +1455,7 @@ func Test_expandAlias(t *testing.T) {
netip.MustParseAddr("100.64.0.2"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
@ -1464,13 +1465,15 @@ func Test_expandAlias(t *testing.T) {
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.3"),
},
User: types.User{Name: "marc"},
User: types.User{Name: "marc"},
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.4"),
},
User: types.User{Name: "joe"},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
},
@ -1520,7 +1523,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.1"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
@ -1531,7 +1534,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.2"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
@ -1541,7 +1544,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.4"),
},
User: types.User{Name: "joe"},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
user: "joe",
@ -1550,6 +1554,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
&types.Node{
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
},
@ -1570,7 +1575,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.1"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
@ -1581,7 +1586,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.2"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
@ -1591,7 +1596,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.4"),
},
User: types.User{Name: "joe"},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
user: "joe",
@ -1600,6 +1606,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
&types.Node{
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
},
@ -1615,7 +1622,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.1"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "foo",
RequestTags: []string{"tag:accountant-webserver"},
@ -1627,12 +1634,14 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
},
User: types.User{Name: "joe"},
ForcedTags: []string{"tag:accountant-webserver"},
Hostinfo: &tailcfg.Hostinfo{},
},
&types.Node{
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.4"),
},
User: types.User{Name: "joe"},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
user: "joe",
@ -1641,6 +1650,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
&types.Node{
IPAddresses: types.NodeAddresses{netip.MustParseAddr("100.64.0.4")},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
},
@ -1656,7 +1666,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.1"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "hr-web1",
RequestTags: []string{"tag:hr-webserver"},
@ -1667,7 +1677,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.2"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "hr-web2",
RequestTags: []string{"tag:hr-webserver"},
@ -1677,7 +1687,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.4"),
},
User: types.User{Name: "joe"},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
user: "joe",
@ -1688,7 +1699,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.1"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "hr-web1",
RequestTags: []string{"tag:hr-webserver"},
@ -1699,7 +1710,7 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
netip.MustParseAddr("100.64.0.2"),
},
User: types.User{Name: "joe"},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
OS: "centos",
Hostname: "hr-web2",
RequestTags: []string{"tag:hr-webserver"},
@ -1709,7 +1720,8 @@ func Test_excludeCorrectlyTaggedNodes(t *testing.T) {
IPAddresses: types.NodeAddresses{
netip.MustParseAddr("100.64.0.4"),
},
User: types.User{Name: "joe"},
User: types.User{Name: "joe"},
Hostinfo: &tailcfg.Hostinfo{},
},
},
},
@ -1952,7 +1964,7 @@ func Test_getTags(t *testing.T) {
User: types.User{
Name: "joe",
},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:valid"},
},
},
@ -1972,7 +1984,7 @@ func Test_getTags(t *testing.T) {
User: types.User{
Name: "joe",
},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:valid", "tag:invalid"},
},
},
@ -1992,7 +2004,7 @@ func Test_getTags(t *testing.T) {
User: types.User{
Name: "joe",
},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{
"tag:invalid",
"tag:valid",
@ -2016,7 +2028,7 @@ func Test_getTags(t *testing.T) {
User: types.User{
Name: "joe",
},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:invalid", "very-invalid"},
},
},
@ -2032,7 +2044,7 @@ func Test_getTags(t *testing.T) {
User: types.User{
Name: "joe",
},
HostInfo: types.HostInfo{
Hostinfo: &tailcfg.Hostinfo{
RequestTags: []string{"tag:invalid", "very-invalid"},
},
},
@ -3010,7 +3022,7 @@ func TestValidExpandTagOwnersInSources(t *testing.T) {
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
pol := &ACLPolicy{
@ -3062,7 +3074,7 @@ func TestInvalidTagValidUser(t *testing.T) {
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
pol := &ACLPolicy{
@ -3113,7 +3125,7 @@ func TestValidExpandTagOwnersInDestinations(t *testing.T) {
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
pol := &ACLPolicy{
@ -3174,7 +3186,7 @@ func TestValidTagInvalidUser(t *testing.T) {
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo),
Hostinfo: &hostInfo,
}
hostInfo2 := tailcfg.Hostinfo{
@ -3191,7 +3203,7 @@ func TestValidTagInvalidUser(t *testing.T) {
Name: "user1",
},
RegisterMethod: util.RegisterMethodAuthKey,
HostInfo: types.HostInfo(hostInfo2),
Hostinfo: &hostInfo2,
}
pol := &ACLPolicy{

View File

@ -83,15 +83,14 @@ func (h *Headscale) handlePoll(
Bool("stream", mapRequest.Stream).
Str("node_key", node.NodeKey.ShortString()).
Str("node", node.Hostname).
Strs("endpoints", node.Endpoints).
Msg("Received endpoint update")
now := time.Now().UTC()
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
node.Hostinfo = mapRequest.Hostinfo
node.DiscoKey = mapRequest.DiscoKey
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
node.Endpoints = mapRequest.Endpoints
if err := h.db.NodeSave(node); err != nil {
logErr(err, "Failed to persist/update node in the database")
@ -142,9 +141,9 @@ func (h *Headscale) handlePoll(
now := time.Now().UTC()
node.LastSeen = &now
node.Hostname = mapRequest.Hostinfo.Hostname
node.HostInfo = types.HostInfo(*mapRequest.Hostinfo)
node.Hostinfo = mapRequest.Hostinfo
node.DiscoKey = mapRequest.DiscoKey
node.SetEndpointsFromAddrPorts(mapRequest.Endpoints)
node.Endpoints = mapRequest.Endpoints
// When a node connects to control, list the peers it has at
// that given point, further updates are kept in memory in

View File

@ -12,33 +12,6 @@ import (
var ErrCannotParsePrefix = errors.New("cannot parse prefix")
// This is a "wrapper" type around tailscales
// Hostinfo to allow us to add database "serialization"
// methods. This allows us to use a typed values throughout
// the code and not have to marshal/unmarshal and error
// check all over the code.
type HostInfo tailcfg.Hostinfo
func (hi *HostInfo) Scan(destination interface{}) error {
switch value := destination.(type) {
case []byte:
return json.Unmarshal(value, hi)
case string:
return json.Unmarshal([]byte(value), hi)
default:
return fmt.Errorf("%w: unexpected data type %T", ErrNodeAddressesInvalid, destination)
}
}
// Value return json value, implement driver.Valuer interface.
func (hi HostInfo) Value() (driver.Value, error) {
bytes, err := json.Marshal(hi)
return string(bytes), err
}
type IPPrefix netip.Prefix
func (i *IPPrefix) Scan(destination interface{}) error {

View File

@ -2,6 +2,7 @@ package types
import (
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"net/netip"
@ -27,27 +28,40 @@ var (
type Node struct {
ID uint64 `gorm:"primary_key"`
// MachineKeyValue is the string representation of MachineKey
// MachineKeyDatabaseField is the string representation of MachineKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use MachineKey instead.
MachineKeyValue string `gorm:"column:machine_key;unique_index"`
MachineKeyDatabaseField string `gorm:"column:machine_key;unique_index"`
MachineKey key.MachinePublic `gorm:"-"`
// NodeKeyValue is the string representation of NodeKey
// NodeKeyDatabaseField is the string representation of NodeKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use NodeKey instead.
NodeKeyValue string `gorm:"column:node_key"`
NodeKeyDatabaseField string `gorm:"column:node_key"`
NodeKey key.NodePublic `gorm:"-"`
// DiscoKeyValue is the string representation of DiscoKey
// DiscoKeyDatabaseField is the string representation of DiscoKey
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use DiscoKey instead.
DiscoKeyValue string `gorm:"column:disco_key"`
DiscoKeyDatabaseField string `gorm:"column:disco_key"`
DiscoKey key.DiscoPublic `gorm:"-"`
MachineKey key.MachinePublic `gorm:"-"`
NodeKey key.NodePublic `gorm:"-"`
DiscoKey key.DiscoPublic `gorm:"-"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
EndpointsDatabaseField StringList `gorm:"column:endpoints"`
Endpoints []netip.AddrPort `gorm:"-"`
// EndpointsDatabaseField is the string list representation of Endpoints
// it is _only_ used for reading and writing the key to the
// database and should not be used.
// Use Endpoints instead.
HostinfoDatabaseField string `gorm:"column:hostinfo"`
Hostinfo *tailcfg.Hostinfo `gorm:"-"`
IPAddresses NodeAddresses
@ -76,9 +90,6 @@ type Node struct {
LastSeen *time.Time
Expiry *time.Time
HostInfo HostInfo
Endpoints StringList
Routes []Route
CreatedAt time.Time
@ -195,31 +206,6 @@ func (node Node) IsExpired() bool {
return time.Now().UTC().After(*node.Expiry)
}
// TODO(kradalby): Try to replace the types in the DB to be correct.
func (node *Node) EndpointsToAddrPort() ([]netip.AddrPort, error) {
var ret []netip.AddrPort
for _, ep := range node.Endpoints {
addrPort, err := netip.ParseAddrPort(ep)
if err != nil {
return nil, err
}
ret = append(ret, addrPort)
}
return ret, nil
}
// TODO(kradalby): Try to replace the types in the DB to be correct.
func (node *Node) SetEndpointsFromAddrPorts(in []netip.AddrPort) {
var strs StringList
for _, addrPort := range in {
strs = append(strs, addrPort.String())
}
node.Endpoints = strs
}
// IsOnline returns if the node is connected to Headscale.
// This is really a naive implementation, as we don't really see
// if there is a working connection between the client and the server.
@ -277,9 +263,22 @@ func (nodes Nodes) FilterByIP(ip netip.Addr) Nodes {
// correctly in the database.
// This currently means storing the keys as strings.
func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
n.MachineKeyValue = n.MachineKey.String()
n.NodeKeyValue = n.NodeKey.String()
n.DiscoKeyValue = n.DiscoKey.String()
n.MachineKeyDatabaseField = n.MachineKey.String()
n.NodeKeyDatabaseField = n.NodeKey.String()
n.DiscoKeyDatabaseField = n.DiscoKey.String()
var endpoints StringList
for _, addrPort := range n.Endpoints {
endpoints = append(endpoints, addrPort.String())
}
n.EndpointsDatabaseField = endpoints
hi, err := json.Marshal(n.Hostinfo)
if err != nil {
return fmt.Errorf("failed to marshal Hostinfo to store in db: %w", err)
}
n.HostinfoDatabaseField = string(hi)
return
}
@ -291,23 +290,40 @@ func (n *Node) BeforeSave(tx *gorm.DB) (err error) {
// the proper types.
func (n *Node) AfterFind(tx *gorm.DB) (err error) {
var machineKey key.MachinePublic
if err := machineKey.UnmarshalText([]byte(n.MachineKeyValue)); err != nil {
return err
if err := machineKey.UnmarshalText([]byte(n.MachineKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal machine key from db: %w", err)
}
n.MachineKey = machineKey
var nodeKey key.NodePublic
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyValue)); err != nil {
return err
if err := nodeKey.UnmarshalText([]byte(n.NodeKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal node key from db: %w", err)
}
n.NodeKey = nodeKey
var discoKey key.DiscoPublic
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyValue)); err != nil {
return err
if err := discoKey.UnmarshalText([]byte(n.DiscoKeyDatabaseField)); err != nil {
return fmt.Errorf("failed to unmarshal disco key from db: %w", err)
}
n.DiscoKey = discoKey
var endpoints []netip.AddrPort
for _, ep := range n.EndpointsDatabaseField {
addrPort, err := netip.ParseAddrPort(ep)
if err != nil {
return fmt.Errorf("failed to parse endpoint from db: %w", err)
}
endpoints = append(endpoints, addrPort)
}
n.Endpoints = endpoints
var hi tailcfg.Hostinfo
if err := json.Unmarshal([]byte(n.HostinfoDatabaseField), &hi); err != nil {
return fmt.Errorf("failed to unmarshal Hostinfo from db: %w", err)
}
n.Hostinfo = &hi
return
}
@ -346,11 +362,6 @@ func (node *Node) Proto() *v1.Node {
return nodeProto
}
// GetHostInfo returns a Hostinfo struct for the node.
func (node *Node) GetHostInfo() tailcfg.Hostinfo {
return tailcfg.Hostinfo(node.HostInfo)
}
func (node *Node) GetFQDN(dnsConfig *tailcfg.DNSConfig, baseDomain string) (string, error) {
var hostname string
if dnsConfig != nil && dnsConfig.Proxied { // MagicDNS