From ae46ce9937fcc3cbf65a4f2b88384f02af638d01 Mon Sep 17 00:00:00 2001 From: Aditya Manthramurthy Date: Thu, 18 Apr 2024 08:15:02 -0700 Subject: [PATCH] ldap: Normalize DNs when importing (#19528) This is a change to IAM export/import functionality. For LDAP enabled setups, it performs additional validations: - for policy mappings on LDAP users and groups, it ensures that the corresponding user or group DN exists and if so uses a normalized form of these DNs for storage - for access keys (service accounts), it updates (i.e. validates existence and normalizes) the internally stored parent user DN and group DNs. This allows for a migration path for setups in which LDAP mappings have been stored in previous versions of the server, where the name of the mapping file stored on drives is not in a normalized form. An administrator needs to execute: `mc admin iam export ALIAS` followed by `mc admin iam import ALIAS /path/to/export/file` The validations are more strict and returns errors when multiple mappings are found for the same user/group DN. This is to ensure the mappings stored by the server are unambiguous and to reduce the potential for confusion. Bonus **bug fix**: IAM export of access keys (service accounts) did not export key name, description and expiration. This is fixed in this change too. --- .typos.toml | 1 + cmd/admin-handlers-users.go | 88 ++++++++----- cmd/iam.go | 164 +++++++++++++++++++++++- cmd/sts-handlers_test.go | 176 ++++++++++++++++++++++++++ internal/config/identity/ldap/ldap.go | 101 +++++++-------- 5 files changed, 438 insertions(+), 92 deletions(-) diff --git a/.typos.toml b/.typos.toml index 38bdfb208..9f67ce391 100644 --- a/.typos.toml +++ b/.typos.toml @@ -15,6 +15,7 @@ extend-ignore-re = [ "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiJ9.*", "MIIDBTCCAe2gAwIBAgIQWHw7h.*", 'http\.Header\{"X-Amz-Server-Side-Encryptio":', + "ZoEoZdLlzVbOlT9rbhD7ZN7TLyiYXSAlB79uGEge", ] [default.extend-words] diff --git a/cmd/admin-handlers-users.go b/cmd/admin-handlers-users.go index 7a1b4595f..1d4ceb7e4 100644 --- a/cmd/admin-handlers-users.go +++ b/cmd/admin-handlers-users.go @@ -1763,9 +1763,20 @@ const ( userPolicyMappingsFile = "user_mappings.json" groupPolicyMappingsFile = "group_mappings.json" stsUserPolicyMappingsFile = "stsuser_mappings.json" - iamAssetsDir = "iam-assets" + + iamAssetsDir = "iam-assets" ) +var iamExportFiles = []string{ + allPoliciesFile, + allUsersFile, + allGroupsFile, + allSvcAcctsFile, + userPolicyMappingsFile, + groupPolicyMappingsFile, + stsUserPolicyMappingsFile, +} + // ExportIAMHandler - exports all iam info as a zipped file func (a adminAPIHandlers) ExportIAM(w http.ResponseWriter, r *http.Request) { ctx := r.Context() @@ -1804,16 +1815,7 @@ func (a adminAPIHandlers) ExportIAM(w http.ResponseWriter, r *http.Request) { return nil } - iamFiles := []string{ - allPoliciesFile, - allUsersFile, - allGroupsFile, - allSvcAcctsFile, - userPolicyMappingsFile, - groupPolicyMappingsFile, - stsUserPolicyMappingsFile, - } - for _, f := range iamFiles { + for _, f := range iamExportFiles { iamFile := pathJoin(iamAssetsDir, f) switch f { case allPoliciesFile: @@ -1898,7 +1900,7 @@ func (a adminAPIHandlers) ExportIAM(w http.ResponseWriter, r *http.Request) { writeErrorResponse(ctx, w, exportError(ctx, err, iamFile, ""), r.URL) return } - _, policy, err := globalIAMSys.GetServiceAccount(ctx, acc.Credentials.AccessKey) + sa, policy, err := globalIAMSys.GetServiceAccount(ctx, acc.Credentials.AccessKey) if err != nil { writeErrorResponse(ctx, w, exportError(ctx, err, iamFile, ""), r.URL) return @@ -1920,6 +1922,9 @@ func (a adminAPIHandlers) ExportIAM(w http.ResponseWriter, r *http.Request) { Claims: claims, SessionPolicy: json.RawMessage(policyJSON), Status: acc.Credentials.Status, + Name: sa.Name, + Description: sa.Description, + Expiration: &sa.Expiration, } } @@ -2184,6 +2189,16 @@ func (a adminAPIHandlers) ImportIAM(w http.ResponseWriter, r *http.Request) { writeErrorResponseJSON(ctx, w, importErrorWithAPIErr(ctx, ErrAdminConfigBadJSON, err, allSvcAcctsFile, ""), r.URL) return } + + // Validations for LDAP enabled deployments. + if globalIAMSys.LDAPConfig.Enabled() { + err := globalIAMSys.NormalizeLDAPAccessKeypairs(ctx, serviceAcctReqs) + if err != nil { + writeErrorResponseJSON(ctx, w, importError(ctx, err, allSvcAcctsFile, ""), r.URL) + return + } + } + for user, svcAcctReq := range serviceAcctReqs { var sp *policy.Policy var err error @@ -2220,20 +2235,14 @@ func (a adminAPIHandlers) ImportIAM(w http.ResponseWriter, r *http.Request) { updateReq = false } if updateReq { - opts := updateServiceAccountOpts{ - secretKey: svcAcctReq.SecretKey, - status: svcAcctReq.Status, - name: svcAcctReq.Name, - description: svcAcctReq.Description, - expiration: svcAcctReq.Expiration, - sessionPolicy: sp, - } - _, err = globalIAMSys.UpdateServiceAccount(ctx, svcAcctReq.AccessKey, opts) + // If the service account exists, we remove it to ensure a + // clean import. + err := globalIAMSys.DeleteServiceAccount(ctx, svcAcctReq.AccessKey, true) if err != nil { - writeErrorResponseJSON(ctx, w, importError(ctx, err, allSvcAcctsFile, user), r.URL) + delErr := fmt.Errorf("failed to delete existing service account(%s) before importing it: %w", svcAcctReq.AccessKey, err) + writeErrorResponseJSON(ctx, w, importError(ctx, delErr, allSvcAcctsFile, user), r.URL) return } - continue } opts := newServiceAccountOpts{ accessKey: user, @@ -2246,18 +2255,6 @@ func (a adminAPIHandlers) ImportIAM(w http.ResponseWriter, r *http.Request) { allowSiteReplicatorAccount: false, } - // In case of LDAP we need to resolve the targetUser to a DN and - // query their groups: - if globalIAMSys.LDAPConfig.Enabled() { - opts.claims[ldapUserN] = svcAcctReq.AccessKey // simple username - targetUser, _, err := globalIAMSys.LDAPConfig.LookupUserDN(svcAcctReq.AccessKey) - if err != nil { - writeErrorResponseJSON(ctx, w, importError(ctx, err, allSvcAcctsFile, user), r.URL) - return - } - opts.claims[ldapUser] = targetUser // username DN - } - if _, _, err = globalIAMSys.NewServiceAccount(ctx, svcAcctReq.Parent, svcAcctReq.Groups, opts); err != nil { writeErrorResponseJSON(ctx, w, importError(ctx, err, allSvcAcctsFile, user), r.URL) return @@ -2326,6 +2323,17 @@ func (a adminAPIHandlers) ImportIAM(w http.ResponseWriter, r *http.Request) { writeErrorResponseJSON(ctx, w, importErrorWithAPIErr(ctx, ErrAdminConfigBadJSON, err, groupPolicyMappingsFile, ""), r.URL) return } + + // Validations for LDAP enabled deployments. + if globalIAMSys.LDAPConfig.Enabled() { + isGroup := true + err := globalIAMSys.NormalizeLDAPMappingImport(ctx, isGroup, grpPolicyMap) + if err != nil { + writeErrorResponseJSON(ctx, w, importError(ctx, err, groupPolicyMappingsFile, ""), r.URL) + return + } + } + for g, pm := range grpPolicyMap { if _, err := globalIAMSys.PolicyDBSet(ctx, g, pm.Policies, unknownIAMUserType, true); err != nil { writeErrorResponseJSON(ctx, w, importError(ctx, err, groupPolicyMappingsFile, g), r.URL) @@ -2355,6 +2363,16 @@ func (a adminAPIHandlers) ImportIAM(w http.ResponseWriter, r *http.Request) { writeErrorResponseJSON(ctx, w, importErrorWithAPIErr(ctx, ErrAdminConfigBadJSON, err, stsUserPolicyMappingsFile, ""), r.URL) return } + + // Validations for LDAP enabled deployments. + if globalIAMSys.LDAPConfig.Enabled() { + isGroup := true + err := globalIAMSys.NormalizeLDAPMappingImport(ctx, !isGroup, userPolicyMap) + if err != nil { + writeErrorResponseJSON(ctx, w, importError(ctx, err, stsUserPolicyMappingsFile, ""), r.URL) + return + } + } for u, pm := range userPolicyMap { // disallow setting policy mapping if user is a temporary user ok, _, err := globalIAMSys.IsTempUser(u) diff --git a/cmd/iam.go b/cmd/iam.go index 1a666367b..67d4a6a58 100644 --- a/cmd/iam.go +++ b/cmd/iam.go @@ -32,6 +32,7 @@ import ( "sync/atomic" "time" + libldap "github.com/go-ldap/ldap/v3" "github.com/minio/madmin-go/v3" "github.com/minio/minio-go/v7/pkg/set" "github.com/minio/minio/internal/arn" @@ -48,6 +49,7 @@ import ( xioutil "github.com/minio/minio/internal/ioutil" "github.com/minio/minio/internal/jwt" "github.com/minio/minio/internal/logger" + "github.com/minio/pkg/v2/ldap" "github.com/minio/pkg/v2/policy" etcd "go.etcd.io/etcd/client/v3" ) @@ -1475,6 +1477,164 @@ func (sys *IAMSys) updateGroupMembershipsForLDAP(ctx context.Context) { } } +// NormalizeLDAPAccessKeypairs - normalize the access key pairs (service +// accounts) for LDAP users. This normalizes the parent user and the group names +// whenever the parent user parses validly as a DN. +func (sys *IAMSys) NormalizeLDAPAccessKeypairs(ctx context.Context, accessKeyMap map[string]madmin.SRSvcAccCreate, +) (err error) { + conn, err := sys.LDAPConfig.LDAP.Connect() + if err != nil { + return err + } + defer conn.Close() + + // Bind to the lookup user account + if err = sys.LDAPConfig.LDAP.LookupBind(conn); err != nil { + return err + } + + var collectedErrors []error + updatedKeysMap := make(map[string]madmin.SRSvcAccCreate) + for ak, createReq := range accessKeyMap { + parent := createReq.Parent + groups := createReq.Groups + + _, err := ldap.NormalizeDN(parent) + if err != nil { + // not a valid DN, ignore. + continue + } + + hasDiff := false + + validatedParent, err := sys.LDAPConfig.GetValidatedUserDN(conn, parent) + if err != nil { + collectedErrors = append(collectedErrors, fmt.Errorf("could not validate `%s` exists in LDAP directory: %w", parent, err)) + continue + } + if validatedParent == "" { + err := fmt.Errorf("DN `%s` was not found in the LDAP directory", parent) + collectedErrors = append(collectedErrors, err) + continue + } + + if validatedParent != parent { + hasDiff = true + } + + var validatedGroups []string + for _, group := range groups { + validatedGroup, err := sys.LDAPConfig.GetValidatedGroupDN(conn, group) + if err != nil { + collectedErrors = append(collectedErrors, fmt.Errorf("could not validate `%s` exists in LDAP directory: %w", group, err)) + continue + } + if validatedGroup == "" { + err := fmt.Errorf("DN `%s` was not found in the LDAP directory", group) + collectedErrors = append(collectedErrors, err) + continue + } + + if validatedGroup != group { + hasDiff = true + } + validatedGroups = append(validatedGroups, validatedGroup) + } + + if hasDiff { + updatedCreateReq := createReq + updatedCreateReq.Parent = validatedParent + updatedCreateReq.Groups = validatedGroups + + updatedKeysMap[ak] = updatedCreateReq + } + } + + // if there are any errors, return a collected error. + if len(collectedErrors) > 0 { + return fmt.Errorf("errors validating LDAP DN: %w", errors.Join(collectedErrors...)) + } + + for k, v := range updatedKeysMap { + // Replace the map values with the updated ones + accessKeyMap[k] = v + } + + return nil +} + +// NormalizeLDAPMappingImport - validates the LDAP policy mappings. Keys in the +// given map may not correspond to LDAP DNs - these keys are ignored. +// +// For validated mappings, it updates the key in the given map to be in +// normalized form. +func (sys *IAMSys) NormalizeLDAPMappingImport(ctx context.Context, isGroup bool, + policyMap map[string]MappedPolicy, +) error { + conn, err := sys.LDAPConfig.LDAP.Connect() + if err != nil { + return err + } + defer conn.Close() + + // Bind to the lookup user account + if err = sys.LDAPConfig.LDAP.LookupBind(conn); err != nil { + return err + } + + // We map keys that correspond to LDAP DNs and validate that they exist in + // the LDAP server. + var dnValidator func(*libldap.Conn, string) (string, error) = sys.LDAPConfig.GetValidatedUserDN + if isGroup { + dnValidator = sys.LDAPConfig.GetValidatedGroupDN + } + + // map of normalized DN keys to original keys. + normalizedDNKeysMap := make(map[string][]string) + var collectedErrors []error + for k := range policyMap { + _, err := ldap.NormalizeDN(k) + if err != nil { + // not a valid DN, ignore. + continue + } + validatedDN, err := dnValidator(conn, k) + if err != nil { + collectedErrors = append(collectedErrors, fmt.Errorf("could not validate `%s` exists in LDAP directory: %w", k, err)) + continue + } + if validatedDN == "" { + err := fmt.Errorf("DN `%s` was not found in the LDAP directory", k) + collectedErrors = append(collectedErrors, err) + continue + } + + if validatedDN != k { + normalizedDNKeysMap[validatedDN] = append(normalizedDNKeysMap[validatedDN], k) + } + } + + // if there are any errors, return a collected error. + if len(collectedErrors) > 0 { + return fmt.Errorf("errors validating LDAP DN: %w", errors.Join(collectedErrors...)) + } + + for normKey, origKeys := range normalizedDNKeysMap { + if len(origKeys) > 1 { + return fmt.Errorf("multiple DNs map to the same LDAP DN[%s]: %v; please remove DNs that are not needed", + normKey, origKeys) + } + + // Replacing origKeys[0] with normKey in the policyMap + + // len(origKeys) is always > 0, so here len(origKeys) == 1 + mappingValue := policyMap[origKeys[0]] + delete(policyMap, origKeys[0]) + policyMap[normKey] = mappingValue + } + return nil +} + // GetUser - get user credentials func (sys *IAMSys) GetUser(ctx context.Context, accessKey string) (u UserIdentity, ok bool) { if !sys.Initialized() { @@ -1605,7 +1765,7 @@ func (sys *IAMSys) PolicyDBSet(ctx context.Context, name, policy string, userTyp if sys.LDAPConfig.Enabled() { if isGroup { var foundGroupDN string - if foundGroupDN, err = sys.LDAPConfig.GetValidatedGroupDN(name); err != nil { + if foundGroupDN, err = sys.LDAPConfig.GetValidatedGroupDN(nil, name); err != nil { iamLogIf(ctx, err) return } else if foundGroupDN == "" { @@ -1754,7 +1914,7 @@ func (sys *IAMSys) PolicyDBUpdateLDAP(ctx context.Context, isAttach bool, } else { if isAttach { var foundGroupDN string - if foundGroupDN, err = sys.LDAPConfig.GetValidatedGroupDN(r.Group); err != nil { + if foundGroupDN, err = sys.LDAPConfig.GetValidatedGroupDN(nil, r.Group); err != nil { iamLogIf(ctx, err) return } else if foundGroupDN == "" { diff --git a/cmd/sts-handlers_test.go b/cmd/sts-handlers_test.go index a20ebd53e..0b6bcd1a6 100644 --- a/cmd/sts-handlers_test.go +++ b/cmd/sts-handlers_test.go @@ -28,6 +28,7 @@ import ( "testing" "time" + "github.com/klauspost/compress/zip" "github.com/minio/madmin-go/v3" minio "github.com/minio/minio-go/v7" cr "github.com/minio/minio-go/v7/pkg/credentials" @@ -799,6 +800,122 @@ func TestIAMExportImportWithLDAP(t *testing.T) { } } +func TestIAMImportAssetWithLDAP(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), testDefaultTimeout) + defer cancel() + + exportContentStrings := map[string]string{ + allPoliciesFile: `{"consoleAdmin":{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["admin:*"]},{"Effect":"Allow","Action":["kms:*"]},{"Effect":"Allow","Action":["s3:*"],"Resource":["arn:aws:s3:::*"]}]},"diagnostics":{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["admin:Prometheus","admin:Profiling","admin:ServerTrace","admin:ConsoleLog","admin:ServerInfo","admin:TopLocksInfo","admin:OBDInfo","admin:BandwidthMonitor"],"Resource":["arn:aws:s3:::*"]}]},"readonly":{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:GetBucketLocation","s3:GetObject"],"Resource":["arn:aws:s3:::*"]}]},"readwrite":{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:*"],"Resource":["arn:aws:s3:::*"]}]},"writeonly":{"Version":"2012-10-17","Statement":[{"Effect":"Allow","Action":["s3:PutObject"],"Resource":["arn:aws:s3:::*"]}]}}`, + allUsersFile: `{}`, + allGroupsFile: `{}`, + allSvcAcctsFile: `{ + "u4ccRswj62HV3Ifwima7": { + "parent": "uid=svc.algorithm,OU=swengg,DC=min,DC=io", + "accessKey": "u4ccRswj62HV3Ifwima7", + "secretKey": "ZoEoZdLlzVbOlT9rbhD7ZN7TLyiYXSAlB79uGEge", + "groups": ["cn=project.c,ou=groups,OU=swengg,DC=min,DC=io"], + "claims": { + "accessKey": "u4ccRswj62HV3Ifwima7", + "ldapUser": "uid=svc.algorithm,ou=swengg,dc=min,dc=io", + "ldapUsername": "svc.algorithm", + "parent": "uid=svc.algorithm,ou=swengg,dc=min,dc=io", + "sa-policy": "inherited-policy" + }, + "sessionPolicy": null, + "status": "on", + "name": "", + "description": "" + } +} +`, + userPolicyMappingsFile: `{}`, + groupPolicyMappingsFile: `{ + "cn=project.c,ou=groups,ou=swengg,DC=min,dc=io": { + "version": 0, + "policy": "consoleAdmin", + "updatedAt": "2024-04-17T23:54:28.442998301Z" + } +} +`, + stsUserPolicyMappingsFile: `{ + "uid=dillon,ou=people,OU=swengg,DC=min,DC=io": { + "version": 0, + "policy": "consoleAdmin", + "updatedAt": "2024-04-17T23:54:10.606645642Z" + } +} +`, + } + exportContent := map[string][]byte{} + for k, v := range exportContentStrings { + exportContent[k] = []byte(v) + } + + var importContent []byte + { + var b bytes.Buffer + zipWriter := zip.NewWriter(&b) + rawDataFn := func(r io.Reader, filename string, sz int) error { + header, zerr := zip.FileInfoHeader(dummyFileInfo{ + name: filename, + size: int64(sz), + mode: 0o600, + modTime: time.Now(), + isDir: false, + sys: nil, + }) + if zerr != nil { + adminLogIf(ctx, zerr) + return nil + } + header.Method = zip.Deflate + zwriter, zerr := zipWriter.CreateHeader(header) + if zerr != nil { + adminLogIf(ctx, zerr) + return nil + } + if _, err := io.Copy(zwriter, r); err != nil { + adminLogIf(ctx, err) + } + return nil + } + for _, f := range iamExportFiles { + iamFile := pathJoin(iamAssetsDir, f) + + fileContent, ok := exportContent[f] + if !ok { + t.Fatalf("missing content for %s", f) + } + + if err := rawDataFn(bytes.NewReader(fileContent), iamFile, len(fileContent)); err != nil { + t.Fatalf("failed to write %s: %v", iamFile, err) + } + } + zipWriter.Close() + importContent = b.Bytes() + } + + for i, testCase := range iamTestSuites { + t.Run( + fmt.Sprintf("Test: %d, ServerType: %s", i+1, testCase.ServerTypeDescription), + func(t *testing.T) { + c := &check{t, testCase.serverType} + suite := testCase + + ldapServer := os.Getenv(EnvTestLDAPServer) + if ldapServer == "" { + c.Skipf("Skipping LDAP test as no LDAP server is provided via %s", EnvTestLDAPServer) + } + + suite.SetUpSuite(c) + suite.SetUpLDAP(c, ldapServer) + suite.TestIAMImportAssetContent(c, importContent) + suite.TearDownSuite(c) + }, + ) + } +} + type iamTestContent struct { policies map[string][]byte ldapUserPolicyMappings map[string][]string @@ -856,6 +973,65 @@ type dummyCloser struct { func (d dummyCloser) Close() error { return nil } +func (s *TestSuiteIAM) TestIAMImportAssetContent(c *check, content []byte) { + ctx, cancel := context.WithTimeout(context.Background(), testDefaultTimeout) + defer cancel() + + dummyCloser := dummyCloser{bytes.NewReader(content)} + err := s.adm.ImportIAM(ctx, dummyCloser) + if err != nil { + c.Fatalf("Unable to import IAM: %v", err) + } + + entRes, err := s.adm.GetLDAPPolicyEntities(ctx, madmin.PolicyEntitiesQuery{}) + if err != nil { + c.Fatalf("Unable to get policy entities: %v", err) + } + + expected := madmin.PolicyEntitiesResult{ + PolicyMappings: []madmin.PolicyEntities{ + { + Policy: "consoleAdmin", + Users: []string{"uid=dillon,ou=people,ou=swengg,dc=min,dc=io"}, + Groups: []string{"cn=project.c,ou=groups,ou=swengg,dc=min,dc=io"}, + }, + }, + } + + entRes.Timestamp = time.Time{} + if !reflect.DeepEqual(expected, entRes) { + c.Fatalf("policy entities mismatch: expected: %v, got: %v", expected, entRes) + } + + dn := "uid=svc.algorithm,ou=swengg,dc=min,dc=io" + res, err := s.adm.ListAccessKeysLDAP(ctx, dn, "") + if err != nil { + c.Fatalf("Unable to list access keys: %v", err) + } + + epochTime := time.Unix(0, 0).UTC() + expectedAccKeys := madmin.ListAccessKeysLDAPResp{ + ServiceAccounts: []madmin.ServiceAccountInfo{ + { + AccessKey: "u4ccRswj62HV3Ifwima7", + Expiration: &epochTime, + }, + }, + } + + if !reflect.DeepEqual(expectedAccKeys, res) { + c.Fatalf("access keys mismatch: expected: %v, got: %v", expectedAccKeys, res) + } + + accKeyInfo, err := s.adm.InfoServiceAccount(ctx, "u4ccRswj62HV3Ifwima7") + if err != nil { + c.Fatalf("Unable to get service account info: %v", err) + } + if accKeyInfo.ParentUser != "uid=svc.algorithm,ou=swengg,dc=min,dc=io" { + c.Fatalf("parent mismatch: expected: %s, got: %s", "uid=svc.algorithm,ou=swengg,dc=min,dc=io", accKeyInfo.ParentUser) + } +} + func (s *TestSuiteIAM) TestIAMImport(c *check, exportedContent []byte, caseNum int, content iamTestContent) { ctx, cancel := context.WithTimeout(context.Background(), testDefaultTimeout) defer cancel() diff --git a/internal/config/identity/ldap/ldap.go b/internal/config/identity/ldap/ldap.go index 99e6c87b1..f83d89dbc 100644 --- a/internal/config/identity/ldap/ldap.go +++ b/internal/config/identity/ldap/ldap.go @@ -96,76 +96,67 @@ func (l *Config) GetValidatedDNForUsername(username string) (string, error) { // Since the username is a valid DN, check that it is under a configured // base DN in the LDAP directory. - - // Check that userDN exists in the LDAP directory. - validatedUserDN, err := xldap.LookupDN(conn, username) - if err != nil { - return "", fmt.Errorf("Error looking up user DN %s: %w", username, err) - } - if validatedUserDN == "" { - return "", nil - } - - // This will return an error as the argument is validated to be a DN. - udn, _ := ldap.ParseDN(validatedUserDN) - - // Check that the user DN is under a configured user base DN in the LDAP - // directory. - for _, baseDN := range l.LDAP.UserDNSearchBaseDistNames { - if baseDN.Parsed.AncestorOf(udn) { - return validatedUserDN, nil - } - } - - return "", fmt.Errorf("User DN %s is not under any configured user base DN", validatedUserDN) + return l.GetValidatedUserDN(conn, username) } -// GetValidatedGroupDN checks if the given group DN exists in the LDAP directory -// and returns the group DN sent by the LDAP server. The value returned by the -// server may not be equal to the input group DN, as LDAP equality is not a -// simple Golang string equality. However, we assume the value returned by the -// LDAP server is canonical. +// GetValidatedUserDN validates the given user DN. Will error out if conn is nil. +func (l *Config) GetValidatedUserDN(conn *ldap.Conn, userDN string) (string, error) { + return l.GetValidatedDNUnderBaseDN(conn, userDN, l.LDAP.UserDNSearchBaseDistNames) +} + +// GetValidatedGroupDN validates the given group DN. If conn is nil, creates a +// connection. +func (l *Config) GetValidatedGroupDN(conn *ldap.Conn, groupDN string) (string, error) { + if conn == nil { + var err error + conn, err = l.LDAP.Connect() + if err != nil { + return "", err + } + defer conn.Close() + + // Bind to the lookup user account + if err = l.LDAP.LookupBind(conn); err != nil { + return "", err + } + } + + return l.GetValidatedDNUnderBaseDN(conn, groupDN, l.LDAP.GroupSearchBaseDistNames) +} + +// GetValidatedDNUnderBaseDN checks if the given DN exists in the LDAP directory +// and returns the DN value sent by the LDAP server. The value returned by the +// server may not be equal to the input DN, as LDAP equality is not a simple +// Golang string equality. However, we assume the value returned by the LDAP +// server is canonical. // -// If the group is not found in the LDAP directory, the returned string is empty +// If the DN is not found in the LDAP directory, the returned string is empty // and err = nil. -func (l *Config) GetValidatedGroupDN(groupDN string) (string, error) { - if len(l.LDAP.GroupSearchBaseDistNames) == 0 { - return "", errors.New("no group search Base DNs given") +func (l *Config) GetValidatedDNUnderBaseDN(conn *ldap.Conn, dn string, baseDNList []xldap.BaseDNInfo) (string, error) { + if len(baseDNList) == 0 { + return "", errors.New("no Base DNs given") } - conn, err := l.LDAP.Connect() + // Check that DN exists in the LDAP directory. + validatedDN, err := xldap.LookupDN(conn, dn) if err != nil { - return "", err + return "", fmt.Errorf("Error looking up DN %s: %w", dn, err) } - defer conn.Close() - - // Bind to the lookup user account - if err = l.LDAP.LookupBind(conn); err != nil { - return "", err - } - - // Check that groupDN exists in the LDAP directory. - validatedGroupDN, err := xldap.LookupDN(conn, groupDN) - if err != nil { - return "", fmt.Errorf("Error looking up group DN %s: %w", groupDN, err) - } - if validatedGroupDN == "" { + if validatedDN == "" { return "", nil } - gdn, err := ldap.ParseDN(validatedGroupDN) - if err != nil { - return "", fmt.Errorf("Given group DN %s could not be parsed: %w", validatedGroupDN, err) - } + // This will not return an error as the argument is validated to be a DN. + pdn, _ := ldap.ParseDN(validatedDN) - // Check that the group DN is under a configured group base DN in the LDAP + // Check that the DN is under a configured base DN in the LDAP // directory. - for _, baseDN := range l.LDAP.GroupSearchBaseDistNames { - if baseDN.Parsed.AncestorOf(gdn) { - return validatedGroupDN, nil + for _, baseDN := range baseDNList { + if baseDN.Parsed.AncestorOf(pdn) { + return validatedDN, nil } } - return "", fmt.Errorf("Group DN %s is not under any configured group base DN", validatedGroupDN) + return "", fmt.Errorf("DN %s is not under any configured base DN", validatedDN) } // Bind - binds to ldap, searches LDAP and returns the distinguished name of the