This is an automated email from the ASF dual-hosted git repository.
pbacsko pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/yunikorn-core.git
The following commit(s) were added to refs/heads/master by this push:
new 22d82c19 [YUNIKORN-656] Add LDAP resolver for group resolution (#1021)
22d82c19 is described below
commit 22d82c1966d449529b63a161e22ed90b3803cefc
Author: Mit Desai <[email protected]>
AuthorDate: Thu Oct 30 15:56:52 2025 +0100
[YUNIKORN-656] Add LDAP resolver for group resolution (#1021)
Closes: #1021
Signed-off-by: Peter Bacsko <[email protected]>
---
go.mod | 4 +
go.sum | 24 +
pkg/common/configs/config.go | 18 +-
pkg/common/configs/config_test.go | 50 ++
pkg/common/constants.go | 28 +
pkg/common/security/ldap_validator.go | 381 ++++++++
pkg/common/security/ldap_validator_test.go | 993 +++++++++++++++++++++
pkg/common/security/usergroup.go | 22 +-
pkg/common/security/usergroup_ldap_resolver.go | 383 ++++++++
.../security/usergroup_ldap_resolver_test.go | 708 +++++++++++++++
pkg/common/security/usergroup_no_resolver_test.go | 100 +++
.../usergroup_os_resolver_test.go} | 38 +-
pkg/common/security/usergroup_test.go | 745 ++++++++++------
pkg/common/security/usergroup_test_resolver.go | 3 +
pkg/scheduler/partition.go | 2 +-
15 files changed, 3197 insertions(+), 302 deletions(-)
diff --git a/go.mod b/go.mod
index 8a8b537a..a6036c7a 100644
--- a/go.mod
+++ b/go.mod
@@ -25,6 +25,7 @@ toolchain go1.23.7
require (
github.com/apache/yunikorn-scheduler-interface
v0.0.0-20251021140208-d3b357b98dcd
+ github.com/go-ldap/ldap/v3 v3.4.11
github.com/google/btree v1.1.3
github.com/google/go-cmp v0.7.0
github.com/google/uuid v1.6.0
@@ -44,13 +45,16 @@ require (
)
require (
+ github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358 //
indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.3.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
+ github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667 //
indirect
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect
github.com/petermattis/goid v0.0.0-20250813065127-a731cc31b4fe //
indirect
github.com/prometheus/procfs v0.12.0 // indirect
go.uber.org/multierr v1.10.0 // indirect
+ golang.org/x/crypto v0.41.0 // indirect
golang.org/x/sys v0.35.0 // indirect
golang.org/x/text v0.28.0 // indirect
google.golang.org/genproto/googleapis/rpc
v0.0.0-20250115164207-1a7da9e5054f // indirect
diff --git a/go.sum b/go.sum
index d144cf81..03f92b72 100644
--- a/go.sum
+++ b/go.sum
@@ -1,3 +1,7 @@
+github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358
h1:mFRzDkZVAjdal+s7s0MwaRv9igoPqLRdzOLzw/8Xvq8=
+github.com/Azure/go-ntlmssp v0.0.0-20221128193559-754e69321358/go.mod
h1:chxPXzSsl7ZWRAuOIE23GDNzjWuZquvFlgA8xmpunjU=
+github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa
h1:LHTHcTQiSGT7VVbI0o4wBRNQIgn917usHWOd6VAffYI=
+github.com/alexbrainman/sspi v0.0.0-20231016080023-1a75b4708caa/go.mod
h1:cEWa1LVoE5KvSD9ONXsZrj0z6KqySlCCNKHlLzbqAt4=
github.com/apache/yunikorn-scheduler-interface
v0.0.0-20251021140208-d3b357b98dcd
h1:7HA8EmjMbw81fQpRDRtLAt2i96PKG080ure1V8Bl7K4=
github.com/apache/yunikorn-scheduler-interface
v0.0.0-20251021140208-d3b357b98dcd/go.mod
h1:fQPKbRdD2fYEjjJG9Gjop95NG2/DoJb939XXSxiuu10=
github.com/beorn7/perks v1.0.1 h1:VlbKKnNfV8bJzeqoa4cOKqO6bYr3WgKZxO8Z16+hsOM=
@@ -6,6 +10,10 @@ github.com/cespare/xxhash/v2 v2.3.0
h1:UL815xU9SqsFlibzuggzjXhog7bL6oX9BbNZnL2UF
github.com/cespare/xxhash/v2 v2.3.0/go.mod
h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs=
github.com/davecgh/go-spew v1.1.1
h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod
h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667
h1:BP4M0CvQ4S3TGls2FvczZtj5Re/2ZzkV9VwqPHH/3Bo=
+github.com/go-asn1-ber/asn1-ber v1.5.8-0.20250403174932-29230038a667/go.mod
h1:hEBeB/ic+5LoWskz+yKT7vGhhPYkProFKoKdwZRWMe0=
+github.com/go-ldap/ldap/v3 v3.4.11
h1:4k0Yxweg+a3OyBLjdYn5OKglv18JNvfDykSoI8bW0gU=
+github.com/go-ldap/ldap/v3 v3.4.11/go.mod
h1:bY7t0FLK8OAVpp/vV6sSlpz3EQDGcQwc8pF0ujLgKvM=
github.com/go-logr/logr v1.4.2 h1:6pFjapn8bFcIbiKo3XT4j/BhANplGihG6tvd+8rYgrY=
github.com/go-logr/logr v1.4.2/go.mod
h1:9T104GzyrTigFIr8wt5mBrctHMim0Nb2HLGrmQ40KvY=
github.com/go-logr/stdr v1.2.2 h1:hSWxHoqTgW2S2qGc0LTAI563KZ5YKYRhT3MFKZMbjag=
@@ -18,6 +26,20 @@ github.com/google/go-cmp v0.7.0
h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod
h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod
h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/hashicorp/go-uuid v1.0.3
h1:2gKiV6YVmrJ1i2CKKa9obLvRieoRGviZFL26PcT/Co8=
+github.com/hashicorp/go-uuid v1.0.3/go.mod
h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
+github.com/jcmturner/aescts/v2 v2.0.0
h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8=
+github.com/jcmturner/aescts/v2 v2.0.0/go.mod
h1:AiaICIRyfYg35RUkr8yESTqvSy7csK90qZ5xfvvsoNs=
+github.com/jcmturner/dnsutils/v2 v2.0.0
h1:lltnkeZGL0wILNvrNiVCR6Ro5PGU/SeBvVO/8c/iPbo=
+github.com/jcmturner/dnsutils/v2 v2.0.0/go.mod
h1:b0TnjGOvI/n42bZa+hmXL+kFJZsFT7G4t3HTlQ184QM=
+github.com/jcmturner/gofork v1.7.6
h1:QH0l3hzAU1tfT3rZCnW5zXl+orbkNMMRGJfdJjHVETg=
+github.com/jcmturner/gofork v1.7.6/go.mod
h1:1622LH6i/EZqLloHfE7IeZ0uEJwMSUyQ/nDd82IeqRo=
+github.com/jcmturner/goidentity/v6 v6.0.1
h1:VKnZd2oEIMorCTsFBnJWbExfNN7yZr3EhJAxwOkZg6o=
+github.com/jcmturner/goidentity/v6 v6.0.1/go.mod
h1:X1YW3bgtvwAXju7V3LCIMpY0Gbxyjn/mY9zx4tFonSg=
+github.com/jcmturner/gokrb5/v8 v8.4.4
h1:x1Sv4HaTpepFkXbt2IkL29DXRf8sOfZXo8eRKh687T8=
+github.com/jcmturner/gokrb5/v8 v8.4.4/go.mod
h1:1btQEpgT6k+unzCwX1KdWMEwPPkkgBtP+F6aCACiMrs=
+github.com/jcmturner/rpc/v2 v2.0.3
h1:7FXXj8Ti1IaVFpSAziCZWNzbNuZmnvw/i6CqLNdWfZY=
+github.com/jcmturner/rpc/v2 v2.0.3/go.mod
h1:VUJYCIDm3PVOEHw8sgt091/20OJjskO/YJki3ELg/Hc=
github.com/julienschmidt/httprouter v1.3.0
h1:U0609e9tgbseu3rBINet9P48AI/D3oJs4dN7jwJOQ1U=
github.com/julienschmidt/httprouter v1.3.0/go.mod
h1:JR6WtHb+2LUe8TCKY3cZOxFyyO8IZAc4RVcycCCAKdM=
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
@@ -66,6 +88,8 @@ go.uber.org/zap v1.27.0
h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8=
go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E=
go.yaml.in/yaml/v3 v3.0.4 h1:tfq32ie2Jv2UxXFdLJdh3jXuOzWiL1fo0bu/FbuKpbc=
go.yaml.in/yaml/v3 v3.0.4/go.mod
h1:DhzuOOF2ATzADvBadXxruRBLzYTpT36CKvDb3+aBEFg=
+golang.org/x/crypto v0.35.0 h1:b15kiHdrGCHrP6LvwaQ3c03kgNhhiMgvlhxHQhmg2Xs=
+golang.org/x/crypto v0.35.0/go.mod
h1:dy7dXNW32cAb/6/PRuTNsix8T+vJAqvuIy5Bli/x0YQ=
golang.org/x/exp v0.0.0-20250228200357-dead58393ab7
h1:aWwlzYV971S4BXRS9AmqwDLAD85ouC6X+pocatKY58c=
golang.org/x/exp v0.0.0-20250228200357-dead58393ab7/go.mod
h1:BHOTPb3L19zxehTsLoJXVaTktb06DFgmdW6Wb9s8jqk=
golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE=
diff --git a/pkg/common/configs/config.go b/pkg/common/configs/config.go
index ce91d9cc..32a4ef2d 100644
--- a/pkg/common/configs/config.go
+++ b/pkg/common/configs/config.go
@@ -45,13 +45,19 @@ type SchedulerConfig struct {
// - a list of placement rule definition objects
// - a list of users specifying limits on the partition
// - the preemption configuration for the partition
+// - user group resolver type (os, ldap, "")
type PartitionConfig struct {
	Name           string
	Queues         []QueueConfig
	PlacementRules []PlacementRule           `yaml:",omitempty" json:",omitempty"`
	Limits         []Limit                   `yaml:",omitempty" json:",omitempty"`
	Preemption     PartitionPreemptionConfig `yaml:",omitempty" json:",omitempty"`
	NodeSortPolicy NodeSortingPolicy         `yaml:",omitempty" json:",omitempty"`
	// UserGroupResolver selects how user group membership is resolved for this
	// partition. Omitting it from the YAML leaves its Type empty.
	UserGroupResolver UserGroupResolver `yaml:",omitempty" json:",omitempty"`
}

// UserGroupResolver holds the user group resolver settings of a partition,
// read from the partition's "usergroupresolver" section.
type UserGroupResolver struct {
	// Type names the resolver implementation: "os", "ldap", or "" when unset.
	// The behaviour behind each value is implemented outside this file.
	Type string `yaml:"type,omitempty" json:"type,omitempty"`
}
// The partition preemption configuration
diff --git a/pkg/common/configs/config_test.go
b/pkg/common/configs/config_test.go
index f0710b59..53f737c0 100644
--- a/pkg/common/configs/config_test.go
+++ b/pkg/common/configs/config_test.go
@@ -2181,3 +2181,53 @@ partitions:
_, err = CreateConfig(data)
assert.ErrorContains(t, err, "group * max resource map[memory:90000
vcore:100000] of queue leaf is greater than immediate or ancestor parent
maximum resource map[memory:10000 vcore:10000000]")
}
+
+// TestUserGroupResolverConfig: tests the user group resolver configuration
+func TestUserGroupResolverConfig(t *testing.T) {
+ data := `
+partitions:
+ -
+ name: default
+ usergroupresolver:
+ type: ldap
+ placementrules:
+ - name: tag
+ value: namespace
+ create: true
+ queues:
+ - name: root
+ submitacl: '*'
+ properties:
+ application.sort.policy: fifo
+ sample: value2
+`
+ // validate the config and check after the update
+ config, err := CreateConfig(data)
+ assert.NilError(t, err)
+
+ // check if the user group resolver is set correctly
+ assert.Equal(t, "ldap", config.Partitions[0].UserGroupResolver.Type)
+
+ // partition with no user group resolver
+ data = `
+partitions:
+ -
+ name: default
+ placementrules:
+ - name: tag
+ value: namespace
+ create: true
+ queues:
+ - name: root
+ submitacl: '*'
+ properties:
+ application.sort.policy: fifo
+ sample: value2
+`
+ // validate the config and check after the update
+ config, err = CreateConfig(data)
+ assert.NilError(t, err)
+
+ // check if the user group resolver is set to empty
+ assert.Equal(t, "", config.Partitions[0].UserGroupResolver.Type)
+}
diff --git a/pkg/common/constants.go b/pkg/common/constants.go
index 7eae1c79..04bc2de3 100644
--- a/pkg/common/constants.go
+++ b/pkg/common/constants.go
@@ -29,4 +29,32 @@ const (
RecoveryQueue = "@recovery@"
RecoveryQueueFull = "root." + RecoveryQueue
DefaultPlacementQueue = "root.default"
+ LdapHost = "Host"
+ LdapPort = "Port"
+ LdapBaseDN = "BaseDN"
+ LdapFilter = "Filter"
+ LdapGroupAttr = "GroupAttr"
+ LdapReturnAttr = "ReturnAttr"
+ LdapBindUser = "BindUser"
+ LdapBindPassword = "BindPassword"
+ LdapInsecure = "Insecure"
+ LdapSSL = "SSL"
+)
+
// Default values for the LDAP resolver settings.
// NOTE(review): where these defaults are applied is not visible in this file;
// presumably they are fallbacks for settings that are not provided — confirm
// against the LDAP resolver.
const (
	DefaultLdapHost   = "localhost"
	DefaultLdapPort   = 389 // standard (non-SSL) LDAP port
	DefaultLdapBaseDN = "dc=example,dc=com"
	// search filter template; "%s" is the user name placeholder
	DefaultLdapFilter    = "(&(sAMAccountName=%s))"
	DefaultLdapGroupAttr = "memberOf"
	// NOTE(review): placeholder credentials; confirm they are never used to
	// bind against a real directory server.
	DefaultLdapBindUser     = "admin"
	DefaultLdapBindPassword = "admin"
	DefaultLdapInsecure     = false
	DefaultLdapSSL          = false
	// NOTE(review): the meaning of this UID is not evident from this file — confirm.
	DefaultLdapUserUID = "1211"
)

var (
	// LdapMountPath is "/run/secrets/ldap"; presumably the directory where
	// the LDAP secret files are mounted — confirm.
	LdapMountPath = "/run/secrets/ldap"
	// DefaultLdapReturnAttr is the default list of attributes requested from
	// LDAP. It is a mutable package variable; do not modify it in place.
	DefaultLdapReturnAttr = []string{"memberOf"}
)
diff --git a/pkg/common/security/ldap_validator.go
b/pkg/common/security/ldap_validator.go
new file mode 100644
index 00000000..40d255de
--- /dev/null
+++ b/pkg/common/security/ldap_validator.go
@@ -0,0 +1,381 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package security
+
+import (
+ "fmt"
+ "net"
+ "regexp"
+ "strconv"
+ "strings"
+
+ "go.uber.org/zap"
+
+ "github.com/apache/yunikorn-core/pkg/common"
+ "github.com/apache/yunikorn-core/pkg/log"
+)
+
// ValidationLevel classifies how serious a validation finding is.
type ValidationLevel int

// Severity levels for LDAP configuration findings; warnings are 0, errors 1.
const (
	// ValidationWarning marks a non-critical finding: operation can continue,
	// but the setting might cause problems.
	ValidationWarning ValidationLevel = 0
	// ValidationError marks a critical finding that prevents proper operation.
	ValidationError ValidationLevel = 1
)

// ValidationIssue describes a single problem found in one configuration field.
type ValidationIssue struct {
	Field   string          // name of the offending configuration field
	Message string          // human readable description of the problem
	Level   ValidationLevel // severity of the problem
}

// LdapValidator collects validation findings for an LDAP configuration.
type LdapValidator struct {
	issues []ValidationIssue
}

// NewLdapValidator returns a validator whose issue list is empty but non-nil.
func NewLdapValidator() *LdapValidator {
	validator := new(LdapValidator)
	validator.issues = []ValidationIssue{}
	return validator
}
+
+// ValidateConfig validates the entire LDAP configuration
+func (v *LdapValidator) ValidateConfig(config *LdapConfig) bool {
+ v.validateHost(config.Host)
+ v.validatePort(config.Port)
+ v.validateBaseDN(config.BaseDN)
+ v.validateFilter(config.Filter)
+ v.validateGroupAttr(config.GroupAttr)
+ v.validateReturnAttr(config.ReturnAttr)
+ v.validateBindUser(config.BindUser)
+ v.validateBindPassword(config.BindPassword)
+
+ // Consistency checks
+ v.validateConsistency(config)
+
+ // Log all issues
+ v.logIssues()
+
+ // Return true if no errors (warnings are acceptable)
+ return !v.hasErrors()
+}
+
+// validateHost validates the LDAP host
+func (v *LdapValidator) validateHost(host string) {
+ if host == "" {
+ v.addIssue("Host", "Host cannot be empty", ValidationError)
+ return
+ }
+
+ // Check if it's an IP address
+ if net.ParseIP(host) != nil {
+ return // Valid IP address
+ }
+
+ // Check if it's a valid hostname
+ hostnameRegex :=
regexp.MustCompile(`^(([a-zA-Z0-9]|[a-zA-Z0-9][a-zA-Z0-9\-]*[a-zA-Z0-9])\.)*([A-Za-z0-9]|[A-Za-z0-9][A-Za-z0-9\-]*[A-Za-z0-9])$`)
+ if !hostnameRegex.MatchString(host) {
+ v.addIssue("Host", fmt.Sprintf("Invalid hostname format: %s",
host), ValidationWarning)
+ }
+}
+
+// validatePort validates the LDAP port
+func (v *LdapValidator) validatePort(port int) {
+ if port < 1 || port > 65535 {
+ v.addIssue("Port", fmt.Sprintf("Port must be between 1 and
65535, got: %d", port), ValidationError)
+ }
+}
+
+// validateBaseDN validates the LDAP base DN
+func (v *LdapValidator) validateBaseDN(baseDN string) {
+ if baseDN == "" {
+ v.addIssue("BaseDN", "BaseDN cannot be empty", ValidationError)
+ return
+ }
+
+ // Check for at least one domain component
+ if !strings.Contains(strings.ToLower(baseDN), "dc=") {
+ v.addIssue("BaseDN", "BaseDN should contain at least one domain
component (dc=)", ValidationWarning)
+ }
+
+ // Check for valid DN format
+ dnRegex :=
regexp.MustCompile(`^(?:(?:[a-zA-Z0-9]+=[^,]+)(?:,(?:[a-zA-Z0-9]+=[^,]+))*)?$`)
+ if !dnRegex.MatchString(baseDN) {
+ v.addIssue("BaseDN", fmt.Sprintf("Invalid DN format: %s",
baseDN), ValidationWarning)
+ }
+}
+
+// validateFilter validates the LDAP filter
+func (v *LdapValidator) validateFilter(filter string) {
+ if filter == "" {
+ v.addIssue("Filter", "Filter cannot be empty", ValidationError)
+ return
+ }
+
+ // Check for username placeholder
+ if !strings.Contains(filter, "%s") {
+ v.addIssue("Filter", "Filter must contain '%s' placeholder for
username substitution", ValidationError)
+ }
+
+ // Check for balanced parentheses
+ if !hasBalancedParentheses(filter) {
+ v.addIssue("Filter", "Filter has unbalanced parentheses",
ValidationError)
+ }
+
+ // Basic filter format check
+ filterRegex := regexp.MustCompile(`^\(.*\)$`)
+ if !filterRegex.MatchString(filter) {
+ v.addIssue("Filter", "Filter should be enclosed in
parentheses", ValidationWarning)
+ }
+}
+
+// validateGroupAttr validates the LDAP group attribute
+func (v *LdapValidator) validateGroupAttr(groupAttr string) {
+ if groupAttr == "" {
+ v.addIssue("GroupAttr", "GroupAttr cannot be empty",
ValidationError)
+ return
+ }
+
+ // Check for valid attribute name format
+ attrRegex := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\-_]*$`)
+ if !attrRegex.MatchString(groupAttr) {
+ v.addIssue("GroupAttr", fmt.Sprintf("Invalid attribute name
format: %s", groupAttr), ValidationWarning)
+ }
+}
+
+// validateReturnAttr validates the LDAP return attributes
+func (v *LdapValidator) validateReturnAttr(returnAttr []string) {
+ if len(returnAttr) == 0 {
+ v.addIssue("ReturnAttr", "ReturnAttr cannot be empty",
ValidationError)
+ return
+ }
+
+ attrRegex := regexp.MustCompile(`^[a-zA-Z][a-zA-Z0-9\-_]*$`)
+ for _, attr := range returnAttr {
+ if !attrRegex.MatchString(attr) {
+ v.addIssue("ReturnAttr", fmt.Sprintf("Invalid attribute
name format: %s", attr), ValidationWarning)
+ }
+ }
+}
+
+// validateBindUser validates the LDAP bind user
+func (v *LdapValidator) validateBindUser(bindUser string) {
+ if bindUser == "" {
+ v.addIssue("BindUser", "BindUser cannot be empty",
ValidationError)
+ return
+ }
+
+ // Check if it's a DN format
+ dnRegex :=
regexp.MustCompile(`^(?:(?:[a-zA-Z0-9]+=[^,]+)(?:,(?:[a-zA-Z0-9]+=[^,]+))*)?$`)
+ if dnRegex.MatchString(bindUser) {
+ return // Valid DN format
+ }
+
+ // Check if it's a username format
+ usernameRegex := regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9\._\-@]*$`)
+ if !usernameRegex.MatchString(bindUser) {
+ v.addIssue("BindUser", fmt.Sprintf("BindUser is neither a valid
DN nor a valid username: %s", bindUser), ValidationWarning)
+ }
+}
+
+// validateBindPassword validates the LDAP bind password
+func (v *LdapValidator) validateBindPassword(bindPassword string) {
+ if bindPassword == "" {
+ v.addIssue("BindPassword", "BindPassword cannot be empty",
ValidationError)
+ return
+ }
+
+ // Check for minimum length
+ if len(bindPassword) < 3 {
+ v.addIssue("BindPassword", "BindPassword is too short",
ValidationWarning)
+ }
+
+ // We don't check for password complexity here as it depends on the
LDAP server policy
+}
+
+// validateConsistency performs cross-field validation
+func (v *LdapValidator) validateConsistency(config *LdapConfig) {
+ // Check SSL and port consistency
+ if config.useSsl && config.Port != 636 {
+ v.addIssue("Port", fmt.Sprintf("SSL is enabled but port is not
the default LDAPS port (636), using: %d", config.Port), ValidationWarning)
+ }
+
+ // Check SSL and Insecure consistency
+ if config.useSsl && config.Insecure {
+ v.addIssue("SSL/Insecure", "Both SSL and Insecure are enabled,
which may indicate a security misconfiguration", ValidationWarning)
+ }
+}
+
+// addIssue adds a validation issue to the list
+func (v *LdapValidator) addIssue(field, message string, level ValidationLevel)
{
+ v.issues = append(v.issues, ValidationIssue{
+ Field: field,
+ Message: message,
+ Level: level,
+ })
+}
+
+// hasErrors checks if there are any validation errors
+func (v *LdapValidator) hasErrors() bool {
+ for _, issue := range v.issues {
+ if issue.Level == ValidationError {
+ return true
+ }
+ }
+ return false
+}
+
+// logIssues logs all validation issues
+func (v *LdapValidator) logIssues() {
+ for _, issue := range v.issues {
+ if issue.Level == ValidationError {
+ log.Log(log.Security).Error("LDAP configuration
validation error",
+ zap.String("field", issue.Field),
+ zap.String("message", issue.Message))
+ } else {
+ log.Log(log.Security).Warn("LDAP configuration
validation warning",
+ zap.String("field", issue.Field),
+ zap.String("message", issue.Message))
+ }
+ }
+}
+
// hasBalancedParentheses reports whether every ')' in s closes an earlier '('
// and no '(' is left open. All other characters are ignored.
func hasBalancedParentheses(s string) bool {
	depth := 0
	for i := 0; i < len(s); i++ {
		// '(' and ')' are single-byte in UTF-8 and never occur inside a
		// multi-byte sequence, so scanning bytes is equivalent to scanning runes
		if s[i] == '(' {
			depth++
		} else if s[i] == ')' {
			if depth == 0 {
				return false
			}
			depth--
		}
	}
	return depth == 0
}
+
+// ValidateSecretValue validates a single secret value based on its key
+func ValidateSecretValue(key, value string) (interface{}, error) {
+ switch key {
+ case common.LdapHost:
+ return validateHostValue(value)
+ case common.LdapPort:
+ return validatePortValue(value)
+ case common.LdapBaseDN:
+ return validateBaseDNValue(value)
+ case common.LdapFilter:
+ return validateFilterValue(value)
+ case common.LdapGroupAttr:
+ return validateGroupAttrValue(value)
+ case common.LdapReturnAttr:
+ return validateReturnAttrValue(value)
+ case common.LdapBindUser:
+ return validateBindUserValue(value)
+ case common.LdapBindPassword:
+ return validateBindPasswordValue(value)
+ case common.LdapInsecure, common.LdapSSL:
+ return validateBoolValue(value)
+ default:
+ return nil, fmt.Errorf("unknown LDAP secret key: %s", key)
+ }
+}
+
// Individual validation functions for each secret type

// validateHostValue accepts any non-empty host secret and returns it unchanged.
func validateHostValue(value string) (string, error) {
	if len(value) != 0 {
		return value, nil
	}
	return "", fmt.Errorf("host cannot be empty")
}
+
// validatePortValue parses the port secret as a decimal integer and checks
// that it lies in 1..65535. A parse failure is wrapped with %w, so callers
// can still inspect the underlying *strconv.NumError via errors.As. The
// message text is unchanged.
func validatePortValue(value string) (int, error) {
	port, err := strconv.Atoi(value)
	if err != nil {
		return 0, fmt.Errorf("invalid port number: %w", err)
	}
	if port < 1 || port > 65535 {
		return 0, fmt.Errorf("port must be between 1 and 65535, got: %d", port)
	}
	return port, nil
}
+
// validateBaseDNValue accepts any non-empty base DN secret and returns it
// unchanged; format checks are left to LdapValidator.
func validateBaseDNValue(value string) (string, error) {
	if isEmpty := value == ""; isEmpty {
		return "", fmt.Errorf("baseDN cannot be empty")
	}
	return value, nil
}
+
+func validateFilterValue(value string) (string, error) {
+ if value == "" {
+ return "", fmt.Errorf("filter cannot be empty")
+ }
+ if !strings.Contains(value, "%s") {
+ return "", fmt.Errorf("filter must contain '%%s' placeholder
for username substitution")
+ }
+ if !hasBalancedParentheses(value) {
+ return "", fmt.Errorf("filter has unbalanced parentheses")
+ }
+ return value, nil
+}
+
// validateGroupAttrValue accepts any non-empty group attribute secret and
// returns it unchanged.
func validateGroupAttrValue(value string) (string, error) {
	switch value {
	case "":
		return "", fmt.Errorf("groupAttr cannot be empty")
	default:
		return value, nil
	}
}
+
// validateReturnAttrValue splits the comma separated ReturnAttr secret into
// attribute names.
//
// Whitespace around each name is trimmed, and empty entries (for example from
// "a,,b" or a trailing comma) are dropped. So "memberOf, cn" yields
// ["memberOf" "cn"] instead of an attribute literally named " cn". An error is
// returned when no attribute name remains, which also covers an empty value.
func validateReturnAttrValue(value string) ([]string, error) {
	parts := strings.Split(value, ",")
	attrs := make([]string, 0, len(parts))
	for _, part := range parts {
		if attr := strings.TrimSpace(part); attr != "" {
			attrs = append(attrs, attr)
		}
	}
	if len(attrs) == 0 {
		return nil, fmt.Errorf("returnAttr cannot be empty")
	}
	return attrs, nil
}
+
// validateBindUserValue accepts any non-empty bind user secret and returns it
// unchanged; format checks are left to LdapValidator.
func validateBindUserValue(value string) (string, error) {
	if value != "" {
		return value, nil
	}
	return "", fmt.Errorf("bindUser cannot be empty")
}
+
// validateBindPasswordValue accepts any non-empty bind password secret and
// returns it unchanged.
func validateBindPasswordValue(value string) (string, error) {
	if len(value) > 0 {
		return value, nil
	}
	return "", fmt.Errorf("bindPassword cannot be empty")
}
+
// validateBoolValue parses a boolean secret with strconv.ParseBool, which
// accepts 1, t, T, TRUE, true, True, 0, f, F, FALSE, false and False. A parse
// failure is wrapped with %w, so the underlying *strconv.NumError remains
// reachable via errors.As. The message text is unchanged.
func validateBoolValue(value string) (bool, error) {
	boolValue, err := strconv.ParseBool(value)
	if err != nil {
		return false, fmt.Errorf("invalid boolean value: %w", err)
	}
	return boolValue, nil
}
diff --git a/pkg/common/security/ldap_validator_test.go
b/pkg/common/security/ldap_validator_test.go
new file mode 100644
index 00000000..c9f7f7f0
--- /dev/null
+++ b/pkg/common/security/ldap_validator_test.go
@@ -0,0 +1,993 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package security
+
+import (
+ "testing"
+
+ "gotest.tools/v3/assert"
+
+ "github.com/apache/yunikorn-core/pkg/common"
+)
+
+func TestValidateHostValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ wantError bool
+ }{
+ {"Valid hostname", "ldap.example.com", false},
+ {"Valid IP", "192.168.1.1", false},
+ {"Empty", "", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := validateHostValue(tt.value)
+ if tt.wantError {
+ assert.Assert(t, err != nil)
+ } else {
+ assert.NilError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidatePortValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ wantError bool
+ }{
+ {"Valid port", "389", false},
+ {"Valid port range", "65535", false},
+ {"Invalid port - too high", "65536", true},
+ {"Invalid port - too low", "0", true},
+ {"Invalid port - not a number", "abc", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := validatePortValue(tt.value)
+ if tt.wantError {
+ assert.Assert(t, err != nil)
+ } else {
+ assert.NilError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateFilterValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ wantError bool
+ }{
+ {"Valid filter", "(&(objectClass=user)(sAMAccountName=%s))",
false},
+ {"Missing placeholder",
"(&(objectClass=user)(sAMAccountName=user))", true},
+ {"Unbalanced parentheses",
"(&(objectClass=user)(sAMAccountName=%s)", true},
+ {"Empty", "", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := validateFilterValue(tt.value)
+ if tt.wantError {
+ assert.Assert(t, err != nil)
+ } else {
+ assert.NilError(t, err)
+ }
+ })
+ }
+}
+
+func TestValidateReturnAttrValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ wantError bool
+ }{
+ {"Valid single attr", "memberOf", false},
+ {"Valid multiple attrs", "memberOf,cn,mail", false},
+ {"Empty", "", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ attrs, err := validateReturnAttrValue(tt.value)
+ if tt.wantError {
+ assert.Assert(t, err != nil)
+ } else {
+ assert.NilError(t, err)
+ switch tt.value {
+ case "memberOf":
+ assert.Equal(t, 1, len(attrs))
+ assert.Equal(t, "memberOf", attrs[0])
+ case "memberOf,cn,mail":
+ assert.Equal(t, 3, len(attrs))
+ assert.Equal(t, "memberOf", attrs[0])
+ assert.Equal(t, "cn", attrs[1])
+ assert.Equal(t, "mail", attrs[2])
+ }
+ }
+ })
+ }
+}
+
+func TestValidateBoolValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ expected bool
+ wantError bool
+ }{
+ {"Valid true", "true", true, false},
+ {"Valid false", "false", false, false},
+ {"Valid 1", "1", true, false},
+ {"Valid 0", "0", false, false},
+ {"Invalid", "notabool", false, true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ val, err := validateBoolValue(tt.value)
+ if tt.wantError {
+ assert.Assert(t, err != nil)
+ } else {
+ assert.NilError(t, err)
+ assert.Equal(t, tt.expected, val)
+ }
+ })
+ }
+}
+
+func TestValidateConfig(t *testing.T) {
+ tests := []struct {
+ name string
+ config *LdapConfig
+ expected bool
+ }{
+ {
+ name: "Valid configuration",
+ config: &LdapConfig{
+ Host: "ldap.example.com",
+ Port: 389,
+ BaseDN: "dc=example,dc=com",
+ Filter:
"(&(objectClass=user)(sAMAccountName=%s))",
+ GroupAttr: "memberOf",
+ ReturnAttr: []string{"memberOf"},
+ BindUser: "cn=admin,dc=example,dc=com",
+ BindPassword: "password",
+ Insecure: false,
+ useSsl: false,
+ },
+ expected: true,
+ },
+ {
+ name: "Invalid configuration - empty fields",
+ config: &LdapConfig{
+ Host: "",
+ Port: 0,
+ BaseDN: "",
+ Filter: "invalid-filter",
+ GroupAttr: "",
+ ReturnAttr: []string{},
+ BindUser: "",
+ BindPassword: "",
+ Insecure: true,
+ useSsl: true,
+ },
+ expected: false,
+ },
+ {
+ name: "Invalid configuration - missing placeholder in
filter",
+ config: &LdapConfig{
+ Host: "ldap.example.com",
+ Port: 389,
+ BaseDN: "dc=example,dc=com",
+ Filter:
"(&(objectClass=user)(sAMAccountName=user))", // Missing %s placeholder
+ GroupAttr: "memberOf",
+ ReturnAttr: []string{"memberOf"},
+ BindUser: "cn=admin,dc=example,dc=com",
+ BindPassword: "password",
+ Insecure: false,
+ useSsl: false,
+ },
+ expected: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ result := validator.ValidateConfig(tt.config)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+// TestValidateBaseDN tests the validateBaseDN method with various inputs
+func TestValidateBaseDN(t *testing.T) {
+ tests := []struct {
+ name string
+ baseDN string
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid BaseDN",
+ baseDN: "dc=example,dc=com",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Empty BaseDN",
+ baseDN: "",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Invalid format - missing value",
+ baseDN: "dc=,dc=com",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Invalid format - missing equals",
+ baseDN: "dcexample,dc=com",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Invalid format - unbalanced commas",
+ baseDN: "dc=example,dc=com,",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "No domain component",
+ baseDN: "cn=admin,ou=users",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Invalid format - extra comma",
+ baseDN: "dc=example,,dc=com",
+ expectWarning: true,
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateBaseDN(tt.baseDN)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "BaseDN" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+// TestValidateGroupAttr tests the validateGroupAttr method with various inputs
+func TestValidateGroupAttr(t *testing.T) {
+ tests := []struct {
+ name string
+ groupAttr string
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid attribute name",
+ groupAttr: "memberOf",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Empty attribute name",
+ groupAttr: "",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Invalid format - starts with number",
+ groupAttr: "1memberOf",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Invalid format - special characters",
+ groupAttr: "member@Of",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Valid with hyphen",
+ groupAttr: "member-of",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid with underscore",
+ groupAttr: "member_of",
+ expectWarning: false,
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateGroupAttr(tt.groupAttr)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "GroupAttr" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+// TestValidateBindUser tests the validateBindUser method with various inputs
+func TestValidateBindUser(t *testing.T) {
+ tests := []struct {
+ name string
+ bindUser string
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid DN format",
+ bindUser: "cn=admin,dc=example,dc=com",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid username format",
+ bindUser: "admin",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid username with domain",
+ bindUser: "[email protected]",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Empty bind user",
+ bindUser: "",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Invalid format - special characters",
+ bindUser: "admin!#$%",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Invalid DN format",
+ bindUser: "cn=admin,dc=example,=com",
+ expectWarning: true,
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateBindUser(tt.bindUser)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "BindUser" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+// TestValidateBindPassword tests the validateBindPassword method with various
inputs
+func TestValidateBindPassword(t *testing.T) {
+ tests := []struct {
+ name string
+ bindPassword string
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid password",
+ bindPassword: "password123",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Empty password",
+ bindPassword: "",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Very short password",
+ bindPassword: "a",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Password with special characters",
+ bindPassword: "p@ssw0rd!",
+ expectWarning: false,
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateBindPassword(tt.bindPassword)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "BindPassword" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+// TestValidateReturnAttr tests the validateReturnAttr method with various
inputs
+func TestValidateReturnAttr(t *testing.T) {
+ tests := []struct {
+ name string
+ returnAttr []string
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid single attribute",
+ returnAttr: []string{"memberOf"},
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid multiple attributes",
+ returnAttr: []string{"memberOf", "cn", "mail"},
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Empty array",
+ returnAttr: []string{},
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Invalid attribute name",
+ returnAttr: []string{"member@Of"},
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Mix of valid and invalid",
+ returnAttr: []string{"memberOf", "123invalid"},
+ expectWarning: true,
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateReturnAttr(tt.returnAttr)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "ReturnAttr" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+// TestValidateHost tests the validateHost method with various inputs
+func TestValidateHost(t *testing.T) {
+ tests := []struct {
+ name string
+ host string
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid hostname",
+ host: "ldap.example.com",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid IP address",
+ host: "192.168.1.1",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Empty host",
+ host: "",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Invalid hostname - starts with hyphen",
+ host: "-ldap.example.com",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Invalid hostname - contains invalid
characters",
+ host: "ldap_example.com",
+ expectWarning: true,
+ expectError: false,
+ },
+ {
+ name: "Invalid hostname - double dots",
+ host: "ldap..example.com",
+ expectWarning: true,
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateHost(tt.host)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "Host" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+// TestValidateFilter tests the validateFilter method with various inputs
+func TestValidateFilter(t *testing.T) {
+ tests := []struct {
+ name string
+ filter string
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid filter",
+ filter:
"(&(objectClass=user)(sAMAccountName=%s))",
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Empty filter",
+ filter: "",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Missing placeholder",
+ filter:
"(&(objectClass=user)(sAMAccountName=user))",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Unbalanced parentheses",
+ filter:
"(&(objectClass=user)(sAMAccountName=%s)",
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Not enclosed in parentheses",
+ filter: "objectClass=user&sAMAccountName=%s",
+ expectWarning: true,
+ expectError: false, // This filter does contain the
%s placeholder, so it's not an error
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateFilter(tt.filter)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "Filter" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+func TestValidateSecretValue(t *testing.T) {
+ tests := []struct {
+ name string
+ key string
+ value string
+ wantError bool
+ }{
+ {"Valid host", common.LdapHost, "ldap.example.com", false},
+ {"Valid port", common.LdapPort, "389", false},
+ {"Valid baseDN", common.LdapBaseDN, "dc=example,dc=com", false},
+ {"Valid filter", common.LdapFilter,
"(&(objectClass=user)(sAMAccountName=%s))", false},
+ {"Valid groupAttr", common.LdapGroupAttr, "memberOf", false},
+ {"Valid returnAttr", common.LdapReturnAttr, "memberOf,cn",
false},
+ {"Valid bindUser", common.LdapBindUser,
"cn=admin,dc=example,dc=com", false},
+ {"Valid bindPassword", common.LdapBindPassword, "password",
false},
+ {"Valid insecure", common.LdapInsecure, "true", false},
+ {"Valid SSL", common.LdapSSL, "false", false},
+ {"Invalid key", "unknown", "value", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := ValidateSecretValue(tt.key, tt.value)
+ if tt.wantError {
+ assert.Assert(t, err != nil)
+ } else {
+ assert.NilError(t, err)
+ }
+ })
+ }
+}
+
+func TestHasBalancedParentheses(t *testing.T) {
+ tests := []struct {
+ name string
+ input string
+ expected bool
+ }{
+ {"Empty string", "", true},
+ {"Simple balanced", "()", true},
+ {"Nested balanced", "((()))", true},
+ {"Complex balanced", "(a(b)c(d(e)f)g)", true},
+ {"Unbalanced - too many open", "(()", false},
+ {"Unbalanced - too many closed", "())", false},
+ {"Unbalanced - wrong order", ")(", false},
+ {"No parentheses", "abc", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ result := hasBalancedParentheses(tt.input)
+ assert.Equal(t, tt.expected, result)
+ })
+ }
+}
+
+func TestLdapValidatorAddIssue(t *testing.T) {
+ validator := NewLdapValidator()
+
+ // Add a warning
+ validator.addIssue("TestField", "Test warning", ValidationWarning)
+
+ // Add an error
+ validator.addIssue("TestField2", "Test error", ValidationError)
+
+ // Check that issues were added
+ assert.Equal(t, 2, len(validator.issues))
+ assert.Equal(t, "TestField", validator.issues[0].Field)
+ assert.Equal(t, "Test warning", validator.issues[0].Message)
+ assert.Equal(t, ValidationWarning, validator.issues[0].Level)
+ assert.Equal(t, "TestField2", validator.issues[1].Field)
+ assert.Equal(t, "Test error", validator.issues[1].Message)
+ assert.Equal(t, ValidationError, validator.issues[1].Level)
+
+ // Check hasErrors
+ assert.Assert(t, validator.hasErrors())
+}
+
+func TestLdapValidatorValidateConsistency(t *testing.T) {
+ tests := []struct {
+ name string
+ config *LdapConfig
+ expectWarning bool
+ }{
+ {
+ name: "No warnings",
+ config: &LdapConfig{
+ useSsl: false,
+ Insecure: false,
+ Port: 389,
+ },
+ expectWarning: false,
+ },
+ {
+ name: "SSL with non-standard port",
+ config: &LdapConfig{
+ useSsl: true,
+ Insecure: false,
+ Port: 389,
+ },
+ expectWarning: true,
+ },
+ {
+ name: "SSL with insecure",
+ config: &LdapConfig{
+ useSsl: true,
+ Insecure: true,
+ Port: 636,
+ },
+ expectWarning: true,
+ },
+ {
+ name: "SSL with standard port",
+ config: &LdapConfig{
+ useSsl: true,
+ Insecure: false,
+ Port: 636,
+ },
+ expectWarning: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validateConsistency(tt.config)
+
+ hasWarnings := len(validator.issues) > 0
+ assert.Equal(t, tt.expectWarning, hasWarnings)
+ })
+ }
+}
+
+// TestValidatePort tests the validatePort method with various inputs
+func TestValidatePort(t *testing.T) {
+ tests := []struct {
+ name string
+ port int
+ expectWarning bool
+ expectError bool
+ }{
+ {
+ name: "Valid port - LDAP",
+ port: 389,
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid port - LDAPS",
+ port: 636,
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid port - custom",
+ port: 1389,
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Invalid port - too low",
+ port: 0,
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Invalid port - too high",
+ port: 65536,
+ expectWarning: false,
+ expectError: true,
+ },
+ {
+ name: "Valid port - minimum",
+ port: 1,
+ expectWarning: false,
+ expectError: false,
+ },
+ {
+ name: "Valid port - maximum",
+ port: 65535,
+ expectWarning: false,
+ expectError: false,
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ validator := NewLdapValidator()
+ validator.validatePort(tt.port)
+
+ hasError := false
+ hasWarning := false
+
+ for _, issue := range validator.issues {
+ if issue.Field == "Port" {
+ switch issue.Level {
+ case ValidationError:
+ hasError = true
+ case ValidationWarning:
+ hasWarning = true
+ }
+ }
+ }
+
+ assert.Equal(t, tt.expectError, hasError, "Expected
error: %v, got: %v", tt.expectError, hasError)
+ assert.Equal(t, tt.expectWarning, hasWarning, "Expected
warning: %v, got: %v", tt.expectWarning, hasWarning)
+ })
+ }
+}
+
+// TestLogIssues tests the logIssues method
+func TestLogIssues(t *testing.T) {
+ validator := NewLdapValidator()
+
+ // Add a warning
+ validator.addIssue("TestField1", "Test warning message",
ValidationWarning)
+
+ // Add an error
+ validator.addIssue("TestField2", "Test error message", ValidationError)
+
+ // Call logIssues - we can't easily capture the log output in a unit
test,
+ // but we can at least verify it doesn't panic
+ validator.logIssues()
+
+ // Verify the issues are still present after logging
+ assert.Equal(t, 2, len(validator.issues))
+ assert.Equal(t, "TestField1", validator.issues[0].Field)
+ assert.Equal(t, ValidationWarning, validator.issues[0].Level)
+ assert.Equal(t, "Test warning message", validator.issues[0].Message)
+ assert.Equal(t, "TestField2", validator.issues[1].Field)
+ assert.Equal(t, ValidationError, validator.issues[1].Level)
+ assert.Equal(t, "Test error message", validator.issues[1].Message)
+}
+
+// TestValidateBindUserValue tests the validateBindUserValue function
+func TestValidateBindUserValue(t *testing.T) {
+ tests := []struct {
+ name string
+ value string
+ wantError bool
+ }{
+ {"Valid DN", "cn=admin,dc=example,dc=com", false},
+ {"Valid username", "admin", false},
+ {"Empty", "", true},
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, err := validateBindUserValue(tt.value)
+ if tt.wantError {
+ assert.Assert(t, err != nil)
+ } else {
+ assert.NilError(t, err)
+ }
+ })
+ }
+}
+
+// TestValidateBaseDNValueEdgeCases tests edge cases for validateBaseDNValue
+func TestValidateBaseDNValueEdgeCases(t *testing.T) {
+ // Test empty value
+ _, err := validateBaseDNValue("")
+ assert.Assert(t, err != nil)
+ assert.ErrorContains(t, err, "baseDN cannot be empty")
+}
+
+// TestValidateGroupAttrValueEdgeCases tests edge cases for
validateGroupAttrValue
+func TestValidateGroupAttrValueEdgeCases(t *testing.T) {
+ // Test empty value
+ _, err := validateGroupAttrValue("")
+ assert.Assert(t, err != nil)
+ assert.ErrorContains(t, err, "groupAttr cannot be empty")
+}
+
+// TestValidateReturnAttrValueEdgeCases tests edge cases for
validateReturnAttrValue
+func TestValidateReturnAttrValueEdgeCases(t *testing.T) {
+ // Test empty value
+ _, err := validateReturnAttrValue("")
+ assert.Assert(t, err != nil)
+ assert.ErrorContains(t, err, "returnAttr cannot be empty")
+}
+
+// TestValidateBindPasswordValueEdgeCases tests edge cases for
validateBindPasswordValue
+func TestValidateBindPasswordValueEdgeCases(t *testing.T) {
+ // Test empty value
+ _, err := validateBindPasswordValue("")
+ assert.Assert(t, err != nil)
+ assert.ErrorContains(t, err, "bindPassword cannot be empty")
+}
diff --git a/pkg/common/security/usergroup.go b/pkg/common/security/usergroup.go
index d9a1966c..eae10a44 100644
--- a/pkg/common/security/usergroup.go
+++ b/pkg/common/security/usergroup.go
@@ -66,20 +66,32 @@ type UserGroup struct {
resolved int64
}
// Resolver type names understood by GetUserGroupCache; the configured
// resolver type is compared against these values.
const (
	// Default (empty type) selects the cache without any resolver.
	Default = ""
	// Ldap selects the resolver that uses the LDAP protocol.
	Ldap = "ldap"
	// Test selects the fake resolver used for testing.
	Test = "test"
	// Os selects the resolver backed by the OS user and group libraries.
	Os = "os"
)
+
// Get the resolver for the user and group info.
// Current setup allows three resolvers:
// * NO resolver: default, no user or group resolution just return the info
(k8s use case)
// * OS resolver: uses the OS libraries to resolve user and group memberships
// * Test resolver: fake resolution for testing
-func GetUserGroupCache(resolver string) *UserGroupCache {
+// * Ldap resolver: uses the LDAP protocol to resolve user and group
memberships
+func GetUserGroupCache(ugr configs.UserGroupResolver, ldapConfigReader
ConfigReader, ldapAccess LdapAccess) *UserGroupCache {
+ resolver := ugr.Type
once.Do(func() {
switch resolver {
- case "test":
+ case Test:
log.Log(log.Security).Info("creating test user group
resolver")
instance = GetUserGroupCacheTest()
- case "os":
+ case Os:
log.Log(log.Security).Info("creating OS user group
resolver")
instance = GetUserGroupCacheOS()
+ case Ldap:
+ log.Log(log.Security).Info("creating LDAP user group
resolver")
+ instance = GetUserGroupCacheLdap(ldapConfigReader,
ldapAccess)
default:
log.Log(log.Security).Info("creating UserGroupCache
without resolver")
instance = GetUserGroupNoResolve()
@@ -231,6 +243,10 @@ func (c *UserGroupCache) Stop() {
if !stopped.Load() {
log.Log(log.Security).Info("Stopping UserGroupCache background
cleanup")
close(c.stop)
+ // Clear the cache before resetting the instance
+ c.lock.Lock()
+ c.ugs = make(map[string]*UserGroup)
+ c.lock.Unlock()
once = &sync.Once{} // re-init so that GetUserGroupCache() can
create a new instance again
instance = nil
stopped.Store(true)
diff --git a/pkg/common/security/usergroup_ldap_resolver.go
b/pkg/common/security/usergroup_ldap_resolver.go
new file mode 100644
index 00000000..eea68b44
--- /dev/null
+++ b/pkg/common/security/usergroup_ldap_resolver.go
@@ -0,0 +1,383 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package security
+
+import (
+ "crypto/tls"
+ "fmt"
+ "os"
+ "os/user"
+ "path/filepath"
+ "strings"
+ "time"
+
+ "github.com/go-ldap/ldap/v3"
+ "go.uber.org/zap"
+
+ "github.com/apache/yunikorn-core/pkg/common"
+ "github.com/apache/yunikorn-core/pkg/log"
+)
+
// This file contains the implementation of the LDAP resolver for user groups.

// LdapLookup provides the user, group and group membership lookups used by
// the LDAP backed UserGroupCache. config holds the connection and search
// settings; every LDAP operation (dial, bind, search, close) goes through
// access.
type LdapLookup struct {
	config LdapConfig // settings loaded by a ConfigReader
	access LdapAccess // LDAP operations; replaceable, e.g. by a mock
}
+
+// LdapAccess defines the interface for LDAP operations
+type LdapAccess interface {
+ // DialURL establishes a connection to the LDAP server
+ DialURL(url string, options ...ldap.DialOpt) (*ldap.Conn, error)
+
+ // Bind authenticates with the LDAP server
+ Bind(conn *ldap.Conn, username, password string) error
+
+ // Search performs an LDAP search operation
+ Search(conn *ldap.Conn, searchRequest *ldap.SearchRequest)
(*ldap.SearchResult, error)
+
+ // Close closes the LDAP connection
+ Close(conn *ldap.Conn)
+}
+
// ConfigReader loads the configuration for the LDAP resolver. An
// implementation is passed to GetUserGroupCacheLdap, so the source of the
// configuration can be substituted (presumably by tests).
type ConfigReader interface {
	// ReadLdapConfig returns the LDAP configuration or an error when it could
	// not be loaded.
	ReadLdapConfig() (*LdapConfig, error)
}

// configReaderImpl is the ConfigReader returned by GetConfigReader; it reads
// the configuration from the files in the LDAP secrets mount directory.
type configReaderImpl struct{}
+
+func (configReaderImpl) ReadLdapConfig() (*LdapConfig, error) {
+ secretsDir := common.LdapMountPath
+
+ // Read all files from secrets directory
+ files, err := os.ReadDir(secretsDir)
+ if err != nil {
+ log.Log(log.Security).Error("Unable to access LDAP secrets
directory",
+ zap.String("directory", secretsDir),
+ zap.Error(err))
+ return nil, fmt.Errorf("unable to access LDAP secrets directory
under %s", secretsDir)
+ }
+
+ secretCount := 0
+ validSecrets := make(map[string]interface{})
+
+ // Iterate over all secret files in the secrets directory
+ for _, file := range files {
+ fileName := file.Name()
+
+ // Skip non-secret entries such as Kubernetes internal metadata
(e.g., symlinks like "..data" or directories like "..timestamp")
+ if strings.HasPrefix(fileName, "..") || file.IsDir() {
+ log.Log(log.Security).Info("Ignoring non-secret entry
(Kubernetes metadata entry or directory)",
+ zap.String("name", fileName))
+ continue
+ }
+
+ secretKey := fileName
+ secretValueBytes, err := os.ReadFile(filepath.Join(secretsDir,
secretKey))
+ if err != nil {
+ log.Log(log.Security).Warn("Could not read secret file",
+ zap.String("file", secretKey),
+ zap.Error(err))
+ continue
+ }
+ secretValue := strings.TrimSpace(string(secretValueBytes))
+
+ // Validate the secret value
+ validatedValue, err := ValidateSecretValue(secretKey,
secretValue)
+ if err != nil {
+ log.Log(log.Security).Warn("Invalid LDAP secret value",
+ zap.String("key", secretKey),
+ zap.Error(err))
+ continue
+ }
+
+ // Store the validated value
+ validSecrets[secretKey] = validatedValue
+ secretCount++
+
+ log.Log(log.Security).Debug("Loaded LDAP secret",
+ zap.String("key", secretKey))
+ }
+
+ ldapConf := getDefaultLdapConfig()
+
+ // Apply validated values to the configuration
+ if host, ok := validSecrets[common.LdapHost].(string); ok {
+ ldapConf.Host = host
+ }
+ if port, ok := validSecrets[common.LdapPort].(int); ok {
+ ldapConf.Port = port
+ }
+ if baseDN, ok := validSecrets[common.LdapBaseDN].(string); ok {
+ ldapConf.BaseDN = baseDN
+ }
+ if filter, ok := validSecrets[common.LdapFilter].(string); ok {
+ ldapConf.Filter = filter
+ }
+ if groupAttr, ok := validSecrets[common.LdapGroupAttr].(string); ok {
+ ldapConf.GroupAttr = groupAttr
+ }
+ if returnAttr, ok := validSecrets[common.LdapReturnAttr].([]string); ok
{
+ ldapConf.ReturnAttr = returnAttr
+ }
+ if bindUser, ok := validSecrets[common.LdapBindUser].(string); ok {
+ ldapConf.BindUser = bindUser
+ }
+ if bindPassword, ok := validSecrets[common.LdapBindPassword].(string);
ok {
+ ldapConf.BindPassword = bindPassword
+ }
+ if insecure, ok := validSecrets[common.LdapInsecure].(bool); ok {
+ ldapConf.Insecure = insecure
+ }
+ if ssl, ok := validSecrets[common.LdapSSL].(bool); ok {
+ ldapConf.useSsl = ssl
+ }
+
+ // Validate the entire configuration
+ validator := NewLdapValidator()
+ isValid := validator.ValidateConfig(ldapConf)
+
+ // Check if all required fields were provided in the secrets
+ requiredFields := []string{
+ common.LdapHost,
+ common.LdapPort,
+ common.LdapBaseDN,
+ common.LdapFilter,
+ common.LdapGroupAttr,
+ common.LdapReturnAttr,
+ common.LdapBindUser,
+ common.LdapBindPassword,
+ }
+
+ var missingFields []string
+ for _, field := range requiredFields {
+ if _, ok := validSecrets[field]; !ok {
+ missingFields = append(missingFields, field)
+ }
+ }
+
+ if len(missingFields) > 0 {
+ log.Log(log.Security).Error("Missing required LDAP
configuration fields",
+ zap.Strings("missingFields", missingFields))
+ isValid = false
+ }
+
+ log.Log(log.Security).Info("Finished loading LDAP secrets",
+ zap.Int("numberOfSecretsLoaded", secretCount),
+ zap.Bool("configurationValid", isValid),
+ zap.Int("missingRequiredFields", len(missingFields)))
+
+ if secretCount == 0 || !isValid || len(missingFields) != 0 {
+ return ldapConf, fmt.Errorf("unable to properly load LDAP
configuration")
+ }
+
+ return ldapConf, nil
+}
+
+func GetConfigReader() ConfigReader {
+ return configReaderImpl{}
+}
+
+func getDefaultLdapConfig() *LdapConfig {
+ return &LdapConfig{
+ Host: common.DefaultLdapHost,
+ Port: common.DefaultLdapPort,
+ BaseDN: common.DefaultLdapBaseDN,
+ Filter: common.DefaultLdapFilter,
+ GroupAttr: common.DefaultLdapGroupAttr,
+ ReturnAttr: common.DefaultLdapReturnAttr,
+ BindUser: common.DefaultLdapBindUser,
+ BindPassword: common.DefaultLdapBindPassword,
+ Insecure: common.DefaultLdapInsecure,
+ useSsl: common.DefaultLdapSSL,
+ }
+}
+
+// ldapAccessImpl implements the LdapAccess interface with real LDAP operations
+type ldapAccessImpl struct{}
+
+func (ldapAccessImpl) DialURL(url string, options ...ldap.DialOpt)
(*ldap.Conn, error) {
+ return ldap.DialURL(url, options...)
+}
+
+func (ldapAccessImpl) Bind(conn *ldap.Conn, username, password string) error {
+ return conn.Bind(username, password)
+}
+
+func (ldapAccessImpl) Search(conn *ldap.Conn, searchRequest
*ldap.SearchRequest) (*ldap.SearchResult, error) {
+ return conn.Search(searchRequest)
+}
+
+func (ldapAccessImpl) Close(conn *ldap.Conn) {
+ _ = conn.Close()
+}
+
+func GetLdapAccess() LdapAccess {
+ return ldapAccessImpl{}
+}
+
// LdapConfig holds the settings the LDAP user group resolver uses to connect
// to, authenticate with and search the directory server.
type LdapConfig struct {
	Host         string   // server host name or address used in the dial URL
	Port         int      // server port used in the dial URL
	BaseDN       string   // base DN of the whole-subtree search
	Filter       string   // search filter; the %s placeholder receives the user name
	GroupAttr    string   // presumably the group membership attribute name — confirm against its users
	ReturnAttr   []string // attributes requested from the server in the search
	BindUser     string   // bind DN or user name used to authenticate
	BindPassword string   // password used to authenticate
	Insecure     bool     // skip TLS certificate verification (InsecureSkipVerify)
	useSsl       bool     // connect with the ldaps:// scheme instead of ldap://
}
+
+func GetUserGroupCacheLdap(reader ConfigReader, access LdapAccess)
*UserGroupCache {
+ config, err := reader.ReadLdapConfig()
+ if err != nil {
+ // Log a FATAL level message - this is very prominent and will
typically cause the application to exit
+ log.Log(log.Security).Fatal("LDAP configuration not found or
invalid. No secrets were loaded from the secrets directory.",
+ zap.String("secretsPath", common.LdapMountPath),
+ zap.String("resolution", "Ensure LDAP secrets are
properly mounted and accessible"))
+
+ // If the Fatal log doesn't cause an exit (depends on logger
configuration),
+ // we could also panic here to ensure the application stops
+ panic("LDAP configuration not found or invalid")
+ }
+
+ ldapLookup := &LdapLookup{
+ config: *config,
+ access: access,
+ }
+
+ return &UserGroupCache{
+ ugs: map[string]*UserGroup{},
+ interval: cleanerInterval * time.Second,
+ lookup: ldapLookup.LdapLookupUser,
+ lookupGroupID: ldapLookup.LdapLookupGroupID,
+ groupIds: ldapLookup.LDAPLookupGroupIds,
+ stop: make(chan struct{}),
+ }
+}
+
+// Default linux behaviour: a user is member of the primary group with the
same name
+func (LdapLookup) LdapLookupUser(userName string) (*user.User, error) {
+ log.Log(log.Security).Debug("Performing LDAP user lookup",
+ zap.String("username", userName),
+ zap.String("defaultUID", common.DefaultLdapUserUID))
+ return &user.User{
+ Uid: common.DefaultLdapUserUID,
+ Gid: userName,
+ Username: userName,
+ }, nil
+}
+
+func (LdapLookup) LdapLookupGroupID(gid string) (*user.Group, error) {
+ log.Log(log.Security).Debug("Looking up LDAP group ID",
+ zap.String("groupID", gid))
+ group := user.Group{Gid: gid}
+ group.Name = gid
+ return &group, nil
+}
+
+func (lu LdapLookup) LDAPLookupGroupIds(osUser *user.User) ([]string, error) {
+ sr, err := ldapSearch(lu.access, lu.config, osUser.Username)
+ if err != nil {
+ log.Log(log.Security).Error("Failed to connect to LDAP for
group lookup",
+ zap.String("user", osUser.Username),
+ zap.Error(err))
+ return nil, err
+ }
+
+ var groups []string
+ for _, entry := range sr.Entries {
+ attr := entry.GetAttributeValues("memberOf")
+ log.Log(log.Security).Debug("LDAP 'memberOf' attributes for
user",
+ zap.String("user", osUser.Username),
+ zap.Strings("attributes", attr))
+ for i := range attr {
+ s := strings.Split(attr[i], ",")
+ newgroup := strings.Split(s[0], "CN=")
+ groups = append(groups, newgroup[1])
+ }
+ }
+ return groups, nil
+}
+
+// ldapSearch performs an LDAP search for the specified username
+// This replaces the old LDAPConn_Bind function with a more testable approach
+func ldapSearch(ldapAccess LdapAccess, ldapConf LdapConfig, userName string)
(*ldap.SearchResult, error) {
+ var ldapUri string
+ if ldapConf.useSsl {
+ ldapUri = "ldaps"
+ } else {
+ ldapUri = "ldap"
+ }
+
+ ldapaddr := fmt.Sprintf("%s://%s:%d", ldapUri, ldapConf.Host,
ldapConf.Port)
+ log.Log(log.Security).Debug("Attempting LDAP connection",
+ zap.String("address", ldapaddr),
+ zap.Bool("ssl", ldapConf.useSsl),
+ zap.Bool("insecureSkipVerify", ldapConf.Insecure))
+
+ l, err := ldapAccess.DialURL(ldapaddr,
+ ldap.DialWithTLSConfig(&tls.Config{InsecureSkipVerify:
ldapConf.Insecure})) // #nosec G402
+ if err != nil {
+ log.Log(log.Security).Error("Error connecting to LDAP server",
+ zap.String("address", ldapaddr),
+ zap.Error(err))
+ return nil, err
+ }
+ defer ldapAccess.Close(l)
+
+ log.Log(log.Security).Debug("LDAP connection successful, attempting
bind",
+ zap.String("bindUser", ldapConf.BindUser))
+ err = ldapAccess.Bind(l, ldapConf.BindUser, ldapConf.BindPassword)
+ if err != nil {
+ log.Log(log.Security).Error("Failed to bind with LDAP server",
+ zap.String("bindDN", ldapConf.BindUser),
+ zap.Error(err))
+ return nil, err
+ }
+
+ filter := fmt.Sprintf(ldapConf.Filter, userName)
+ log.Log(log.Security).Debug("Executing LDAP search",
+ zap.String("baseDN", ldapConf.BaseDN),
+ zap.String("filter", filter),
+ zap.Strings("attributesToReturn", ldapConf.ReturnAttr))
+
+ searchRequest := ldap.NewSearchRequest(
+ ldapConf.BaseDN,
+ ldap.ScopeWholeSubtree, ldap.NeverDerefAliases, 0, 0, false,
+ filter,
+ ldapConf.ReturnAttr,
+ nil,
+ )
+ sr, err := ldapAccess.Search(l, searchRequest)
+ if err != nil {
+ log.Log(log.Security).Error("Failed to execute LDAP search
query",
+ zap.String("filter", filter),
+ zap.String("baseDN", ldapConf.BaseDN),
+ zap.Error(err))
+ return nil, err
+ }
+
+ log.Log(log.Security).Debug("LDAP search completed successfully",
+ zap.String("username", userName),
+ zap.Int("entriesFound", len(sr.Entries)))
+ return sr, nil
+}
diff --git a/pkg/common/security/usergroup_ldap_resolver_test.go
b/pkg/common/security/usergroup_ldap_resolver_test.go
new file mode 100644
index 00000000..17f28cf8
--- /dev/null
+++ b/pkg/common/security/usergroup_ldap_resolver_test.go
@@ -0,0 +1,708 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package security
+
+import (
+ "errors"
+ "fmt"
+ "os"
+ "os/user"
+ "path/filepath"
+ "strconv"
+ "strings"
+ "testing"
+ "time"
+
+ "github.com/go-ldap/ldap/v3"
+ "gotest.tools/v3/assert"
+
+ "github.com/apache/yunikorn-core/pkg/common"
+)
+
+// Mock LDAP search result for testing
+func mockLdapSearchResult(username string) (*ldap.SearchResult, error) {
+ if username == Testuser1 || username == Testuser {
+ return &ldap.SearchResult{
+ Entries: []*ldap.Entry{
+ {
+ Attributes: []*ldap.EntryAttribute{
+ {
+ Name: "memberOf",
+ Values:
[]string{"CN=group1001,OU=groups,DC=example,DC=com"},
+ },
+ },
+ },
+ },
+ }, nil
+ }
+ if username == Testuser2 {
+ return &ldap.SearchResult{
+ Entries: []*ldap.Entry{
+ {
+ Attributes: []*ldap.EntryAttribute{
+ {
+ Name: "memberOf",
+ Values:
[]string{"CN=group1001,OU=groups,DC=example,DC=com",
"CN=group1002,OU=groups,DC=example,DC=com"},
+ },
+ },
+ },
+ },
+ }, nil
+ }
+ if username == Testuser3 {
+ return &ldap.SearchResult{
+ Entries: []*ldap.Entry{
+ {
+ Attributes: []*ldap.EntryAttribute{
+ {
+ Name: "memberOf",
+ Values:
[]string{"CN=group1002,OU=groups,DC=example,DC=com",
"CN=group1001,OU=groups,DC=example,DC=com",
"CN=group1003,OU=groups,DC=example,DC=com",
"CN=group1004,OU=groups,DC=example,DC=com"},
+ },
+ },
+ },
+ },
+ }, nil
+ }
+ if username == Testuser4 {
+ return &ldap.SearchResult{
+ Entries: []*ldap.Entry{
+ {
+ Attributes: []*ldap.EntryAttribute{
+ {
+ Name: "memberOf",
+ Values:
[]string{"CN=group901,OU=groups,DC=example,DC=com",
"CN=group902,OU=groups,DC=example,DC=com"},
+ },
+ },
+ },
+ },
+ }, nil
+ }
+ return nil, fmt.Errorf("ldap lookup failed for user: %s", username)
+}
+
+// LdapAccessMock implements the LdapAccess interface for testing
+type LdapAccessMock struct {
+ DialURLFunc func(url string, options ...ldap.DialOpt) (*ldap.Conn,
error)
+ BindFunc func(conn *ldap.Conn, username, password string) error
+ SearchFunc func(conn *ldap.Conn, searchRequest *ldap.SearchRequest)
(*ldap.SearchResult, error)
+ CloseFunc func(conn *ldap.Conn)
+ SearchResult *ldap.SearchResult
+ Error error
+}
+
+type ConfigReaderMock struct{}
+
+func (ConfigReaderMock) ReadLdapConfig() (*LdapConfig, error) {
+ return &LdapConfig{}, nil
+}
+
+func (m *LdapAccessMock) DialURL(url string, options ...ldap.DialOpt)
(*ldap.Conn, error) {
+ if m.DialURLFunc != nil {
+ return m.DialURLFunc(url, options...)
+ }
+ return &ldap.Conn{}, nil
+}
+
+func (m *LdapAccessMock) Bind(conn *ldap.Conn, username, password string)
error {
+ if m.BindFunc != nil {
+ return m.BindFunc(conn, username, password)
+ }
+ return nil
+}
+
+func (m *LdapAccessMock) Search(conn *ldap.Conn, searchRequest
*ldap.SearchRequest) (*ldap.SearchResult, error) {
+ if m.SearchFunc != nil {
+ return m.SearchFunc(conn, searchRequest)
+ }
+ return m.SearchResult, m.Error
+}
+
+func (m *LdapAccessMock) Close(conn *ldap.Conn) {
+ if m.CloseFunc != nil {
+ m.CloseFunc(conn)
+ }
+}
+
+// Helper function to create a mock LDAP access with predefined search results
+func newMockLdapAccess(searchResult *ldap.SearchResult, err error)
*LdapAccessMock {
+ return &LdapAccessMock{
+ SearchResult: searchResult,
+ Error: err,
+ }
+}
+
+// TestLdapSearch tests the new ldapSearch function with a mock LdapAccess
+func TestLdapSearch(t *testing.T) {
+ // Create a mock search result
+ mockResult := &ldap.SearchResult{
+ Entries: []*ldap.Entry{
+ {
+ Attributes: []*ldap.EntryAttribute{
+ {
+ Name: "memberOf",
+ Values:
[]string{"CN=group1,OU=groups,DC=example,DC=com",
"CN=group2,OU=groups,DC=example,DC=com"},
+ },
+ },
+ },
+ },
+ }
+
+ // Create a mock LDAP access with the mock result
+ mockAccess := newMockLdapAccess(mockResult, nil)
+ savedUrl := ""
+ mockAccess.DialURLFunc = func(url string, _ ...ldap.DialOpt)
(*ldap.Conn, error) {
+ savedUrl = url
+ return &ldap.Conn{}, nil
+ }
+
+ // Call ldapSearch with the mock access
+ ldapConf := LdapConfig{
+ Host: "testhost",
+ Port: 1234,
+ }
+ result, err := ldapSearch(mockAccess, ldapConf, "testuser")
+
+ // Verify results
+ assert.NilError(t, err)
+ assert.Assert(t, result != nil)
+ assert.Equal(t, 1, len(result.Entries))
+ assert.Equal(t, 1, len(result.Entries[0].Attributes))
+ assert.Equal(t, "memberOf", result.Entries[0].Attributes[0].Name)
+ assert.Equal(t, 2, len(result.Entries[0].Attributes[0].Values))
+ assert.Equal(t, "CN=group1,OU=groups,DC=example,DC=com",
result.Entries[0].Attributes[0].Values[0])
+ assert.Equal(t, "CN=group2,OU=groups,DC=example,DC=com",
result.Entries[0].Attributes[0].Values[1])
+ assert.Equal(t, "ldap://testhost:1234", savedUrl)
+
+ // check useSsl
+ ldapConf = LdapConfig{
+ Host: "testhost",
+ Port: 1234,
+ useSsl: true,
+ }
+ _, err = ldapSearch(mockAccess, ldapConf, "testuser")
+ assert.NilError(t, err)
+ assert.Equal(t, "ldaps://testhost:1234", savedUrl)
+}
+
+// TestLdapSearchError tests the error handling in ldapSearch
+func TestLdapSearchError(t *testing.T) {
+ // Test cases for different error scenarios
+ testCases := []struct {
+ name string
+ dialError error
+ bindError error
+ searchError error
+ }{
+ {
+ name: "Dial Error",
+ dialError: errors.New("dial error"),
+ bindError: nil,
+ searchError: nil,
+ },
+ {
+ name: "Bind Error",
+ dialError: nil,
+ bindError: errors.New("bind error"),
+ searchError: nil,
+ },
+ {
+ name: "Search Error",
+ dialError: nil,
+ bindError: nil,
+ searchError: errors.New("search error"),
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Create a mock LDAP access with the appropriate error
+ mockAccess := &LdapAccessMock{
+ DialURLFunc: func(url string, options
...ldap.DialOpt) (*ldap.Conn, error) {
+ return &ldap.Conn{}, tc.dialError
+ },
+ BindFunc: func(conn *ldap.Conn, username,
password string) error {
+ return tc.bindError
+ },
+ SearchFunc: func(conn *ldap.Conn, searchRequest
*ldap.SearchRequest) (*ldap.SearchResult, error) {
+ return nil, tc.searchError
+ },
+ }
+
+ // Call ldapSearch with the mock access
+ result, err := ldapSearch(mockAccess, LdapConfig{},
"testuser")
+
+ // Verify error
+ assert.Assert(t, err != nil)
+ assert.Assert(t, result == nil)
+
+ // Check for specific error
+ switch {
+ case tc.dialError != nil:
+ assert.Equal(t, tc.dialError.Error(),
err.Error())
+ case tc.bindError != nil:
+ assert.Equal(t, tc.bindError.Error(),
err.Error())
+ case tc.searchError != nil:
+ assert.Equal(t, tc.searchError.Error(),
err.Error())
+ }
+ })
+ }
+}
+
+func TestLdapLookups(t *testing.T) {
+ tests := []struct {
+ name string
+ testType string
+ id string
+ validate func(t *testing.T, result interface{}, err error)
+ }{
+ {
+ name: "Lookup user",
+ testType: "user",
+ id: "testuser",
+ validate: func(t *testing.T, result interface{}, err
error) {
+ assert.NilError(t, err)
+ u, ok := result.(*user.User)
+ assert.Assert(t, ok, "invalid result type")
+ assert.Equal(t, "testuser", u.Username)
+ assert.Equal(t, "testuser", u.Gid)
+ assert.Equal(t, "1211", u.Uid)
+ },
+ },
+ {
+ name: "Lookup group",
+ testType: "group",
+ id: "testgroup",
+ validate: func(t *testing.T, result interface{}, err
error) {
+ assert.NilError(t, err)
+ g, ok := result.(*user.Group)
+ assert.Assert(t, ok, "invalid result type")
+ assert.Equal(t, "testgroup", g.Gid)
+ assert.Equal(t, "testgroup", g.Name)
+ },
+ },
+ }
+
+ lu := &LdapLookup{}
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ switch tt.testType {
+ case "user":
+ u, err := lu.LdapLookupUser(tt.id)
+ tt.validate(t, u, err)
+ case "group":
+ g, err := lu.LdapLookupGroupID(tt.id)
+ tt.validate(t, g, err)
+ }
+ })
+ }
+}
+
+func TestLDAPLookupGroupIds(t *testing.T) {
+ // Create a mock search result
+ mockResult := &ldap.SearchResult{
+ Entries: []*ldap.Entry{
+ {
+ Attributes: []*ldap.EntryAttribute{
+ {
+ Name: "memberOf",
+ Values:
[]string{"CN=group1,OU=groups,DC=example,DC=com",
"CN=group2,OU=groups,DC=example,DC=com"},
+ },
+ },
+ },
+ },
+ }
+
+ u := &user.User{Username: "testuser"}
+ lu := &LdapLookup{
+ access: newMockLdapAccess(mockResult, nil),
+ config: LdapConfig{},
+ }
+
+ groups, err := lu.LDAPLookupGroupIds(u)
+ assert.NilError(t, err)
+ assert.Assert(t, strings.Contains(strings.Join(groups, ","), "group1"))
+ assert.Assert(t, strings.Contains(strings.Join(groups, ","), "group2"))
+}
+
+func TestLDAPLookupGroupIdsError(t *testing.T) {
+ u := &user.User{Username: "testuser"}
+ lu := &LdapLookup{
+ access: newMockLdapAccess(nil, errors.New("ldap error")),
+ config: LdapConfig{},
+ }
+ groups, err := lu.LDAPLookupGroupIds(u)
+ assert.Error(t, err, "ldap error")
+ assert.Assert(t, groups == nil)
+}
+
+//nolint:funlen // Table-driven test for coverage, helpers used to reduce
length
+func TestReadSecrets(t *testing.T) {
+ tests := getReadSecretsTestCases()
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ _, cleanup := tt.setupFunc(t)
+ defer cleanup()
+ reader := &configReaderImpl{}
+ ldapConf, err := reader.ReadLdapConfig()
+ assert.Equal(t, tt.expectedResult, err == nil)
+ if tt.nilConf && ldapConf == nil {
+ return
+ }
+ tt.validateFunc(t, ldapConf)
+ })
+ }
+}
+
+//nolint:funlen // Table-driven test helper for coverage, intentionally long
+func getReadSecretsTestCases() []struct {
+ name string
+ setupFunc func(t *testing.T) (string, func())
+ expectedResult bool
+ nilConf bool
+ validateFunc func(t *testing.T, conf *LdapConfig)
+} {
+ return []struct {
+ name string
+ setupFunc func(t *testing.T) (string, func())
+ expectedResult bool
+ nilConf bool
+ validateFunc func(t *testing.T, conf *LdapConfig)
+ }{
+ {
+ name: "Skips K8s metadata and directories",
+ setupFunc: func(t *testing.T) (string, func()) {
+ tmpDir := t.TempDir()
+ err := os.Mkdir(filepath.Join(tmpDir,
"..data"), 0755)
+ assert.NilError(t, err)
+ err = os.Mkdir(filepath.Join(tmpDir, "dir1"),
0755)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
"key1"), []byte("value1"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
"..timestamp"), []byte("meta"), 0600)
+ assert.NilError(t, err)
+ origLdapMountPath := common.LdapMountPath
+ common.LdapMountPath = tmpDir
+ return tmpDir, func() { common.LdapMountPath =
origLdapMountPath }
+ },
+ expectedResult: false,
+ validateFunc: func(t *testing.T, ldapConf *LdapConfig) {
+ assert.Equal(t, common.DefaultLdapHost,
ldapConf.Host)
+ assert.Equal(t, common.DefaultLdapPort,
ldapConf.Port)
+ assert.Equal(t, common.DefaultLdapBaseDN,
ldapConf.BaseDN)
+ assert.Equal(t, common.DefaultLdapFilter,
ldapConf.Filter)
+ assert.Equal(t, common.DefaultLdapGroupAttr,
ldapConf.GroupAttr)
+ assert.Equal(t,
strings.Join(common.DefaultLdapReturnAttr, ","),
strings.Join(ldapConf.ReturnAttr, ","))
+ assert.Equal(t, common.DefaultLdapBindUser,
ldapConf.BindUser)
+ assert.Equal(t, common.DefaultLdapBindPassword,
ldapConf.BindPassword)
+ assert.Equal(t, common.DefaultLdapInsecure,
ldapConf.Insecure)
+ assert.Equal(t, common.DefaultLdapSSL,
ldapConf.useSsl)
+ },
+ },
+ {
+ name: "Handles missing secrets directory",
+ setupFunc: func(t *testing.T) (string, func()) {
+ origLdapMountPath := common.LdapMountPath
+ common.LdapMountPath = "/nonexistent"
+ return "/nonexistent", func() {
common.LdapMountPath = origLdapMountPath }
+ },
+ expectedResult: false,
+ nilConf: true,
+ validateFunc: func(t *testing.T, ldapConf
*LdapConfig) {},
+ },
+ {
+ name: "Handles unknown key",
+ setupFunc: func(t *testing.T) (string, func()) {
+ tmpDir := t.TempDir()
+ err := os.WriteFile(filepath.Join(tmpDir,
"unknownKey"), []byte("somevalue"), 0600)
+ assert.NilError(t, err)
+ origLdapMountPath := common.LdapMountPath
+ common.LdapMountPath = tmpDir
+ return tmpDir, func() { common.LdapMountPath =
origLdapMountPath }
+ },
+ expectedResult: false,
+ validateFunc: func(t *testing.T, ldapConf *LdapConfig) {
+ assert.Equal(t, common.DefaultLdapHost,
ldapConf.Host)
+ assert.Equal(t, common.DefaultLdapPort,
ldapConf.Port)
+ assert.Equal(t, common.DefaultLdapBaseDN,
ldapConf.BaseDN)
+ assert.Equal(t, common.DefaultLdapFilter,
ldapConf.Filter)
+ assert.Equal(t, common.DefaultLdapGroupAttr,
ldapConf.GroupAttr)
+ assert.Equal(t,
strings.Join(common.DefaultLdapReturnAttr, ","),
strings.Join(ldapConf.ReturnAttr, ","))
+ assert.Equal(t, common.DefaultLdapBindUser,
ldapConf.BindUser)
+ assert.Equal(t, common.DefaultLdapBindPassword,
ldapConf.BindPassword)
+ assert.Equal(t, common.DefaultLdapInsecure,
ldapConf.Insecure)
+ assert.Equal(t, common.DefaultLdapSSL,
ldapConf.useSsl)
+ },
+ },
+ {
+ name: "Handles invalid port and bool values",
+ setupFunc: func(t *testing.T) (string, func()) {
+ tmpDir := t.TempDir()
+ err := os.WriteFile(filepath.Join(tmpDir,
common.LdapPort), []byte("notanint"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapInsecure), []byte("notabool"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapSSL), []byte("notabool"), 0600)
+ assert.NilError(t, err)
+ origLdapMountPath := common.LdapMountPath
+ common.LdapMountPath = tmpDir
+ return tmpDir, func() { common.LdapMountPath =
origLdapMountPath }
+ },
+ expectedResult: false,
+ validateFunc: func(t *testing.T, ldapConf *LdapConfig) {
+ // Assert that ldapConf.Port is set to
DefaultLdapPort when invalid int value is provided
+ assert.Equal(t, common.DefaultLdapPort,
ldapConf.Port)
+
+ // Assert that rest of ldap conf is set to
default values
+ assert.Equal(t, common.DefaultLdapHost,
ldapConf.Host)
+ assert.Equal(t, common.DefaultLdapBaseDN,
ldapConf.BaseDN)
+ assert.Equal(t, common.DefaultLdapFilter,
ldapConf.Filter)
+ assert.Equal(t, common.DefaultLdapGroupAttr,
ldapConf.GroupAttr)
+ assert.Equal(t,
strings.Join(common.DefaultLdapReturnAttr, ","),
strings.Join(ldapConf.ReturnAttr, ","))
+ assert.Equal(t, common.DefaultLdapBindUser,
ldapConf.BindUser)
+ assert.Equal(t, common.DefaultLdapBindPassword,
ldapConf.BindPassword)
+ assert.Equal(t, common.DefaultLdapInsecure,
ldapConf.Insecure)
+ assert.Equal(t, common.DefaultLdapSSL,
ldapConf.useSsl)
+ },
+ },
+ {
+ name: "Sets custom values",
+ setupFunc: func(t *testing.T) (string, func()) {
+ tmpDir := t.TempDir()
+ err := os.WriteFile(filepath.Join(tmpDir,
common.LdapHost), []byte("myhost"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapPort), []byte("1234"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapBaseDN), []byte("dc=test,dc=com"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapFilter), []byte("(&(uid=%s))"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapGroupAttr), []byte("groups"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapReturnAttr), []byte("memberOf,groups"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapBindUser), []byte("binduser"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapBindPassword), []byte("bindpass"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapInsecure), []byte("true"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapSSL), []byte("true"), 0600)
+ assert.NilError(t, err)
+ origLdapMountPath := common.LdapMountPath
+ common.LdapMountPath = tmpDir
+ return tmpDir, func() { common.LdapMountPath =
origLdapMountPath }
+ },
+ expectedResult: true,
+ validateFunc: func(t *testing.T, ldapConf *LdapConfig) {
+ assert.Equal(t, "myhost", ldapConf.Host)
+
+ // Use strconv to verify the port value to
ensure the import is used
+ portStr := "1234"
+ expectedPort, err := strconv.Atoi(portStr)
+ assert.NilError(t, err, "failed to convert port
string to int")
+ assert.Equal(t, expectedPort, ldapConf.Port)
+
+ assert.Equal(t, "dc=test,dc=com",
ldapConf.BaseDN)
+ assert.Equal(t, "(&(uid=%s))", ldapConf.Filter)
+ assert.Equal(t, "groups", ldapConf.GroupAttr)
+ assert.Equal(t, "memberOf,groups",
strings.Join(ldapConf.ReturnAttr, ","))
+ assert.Equal(t, "binduser", ldapConf.BindUser)
+ assert.Equal(t, "bindpass",
ldapConf.BindPassword)
+
+ // Use strconv to verify boolean values
+ insecureStr := "true"
+ expectedInsecure, err :=
strconv.ParseBool(insecureStr)
+ assert.NilError(t, err, "failed to convert
insecure string to bool")
+ assert.Equal(t, expectedInsecure,
ldapConf.Insecure)
+
+ sslStr := "true"
+ expectedSSL, err := strconv.ParseBool(sslStr)
+ assert.NilError(t, err, "failed to convert ssl
string to bool")
+ assert.Equal(t, expectedSSL, ldapConf.useSsl)
+ },
+ },
+ {
+ name: "Missing required fields",
+ setupFunc: func(t *testing.T) (string, func()) {
+ tmpDir := t.TempDir()
+ err := os.WriteFile(filepath.Join(tmpDir,
common.LdapHost), []byte("ldap.example.com"), 0600)
+ assert.NilError(t, err)
+ err = os.WriteFile(filepath.Join(tmpDir,
common.LdapPort), []byte("389"), 0600)
+ assert.NilError(t, err)
+ // Missing BaseDN, Filter, GroupAttr,
ReturnAttr, BindUser, BindPassword
+ origLdapMountPath := common.LdapMountPath
+ common.LdapMountPath = tmpDir
+ return tmpDir, func() { common.LdapMountPath =
origLdapMountPath }
+ },
+ expectedResult: false,
+ validateFunc: func(_ *testing.T, _ *LdapConfig) {
+ // No specific validation needed - we're
testing the return value
+ },
+ },
+ {
+ name: "All required fields present",
+ setupFunc: func(t *testing.T) (string, func()) {
+ tmpDir := t.TempDir()
+ requiredFields := map[string]string{
+ common.LdapHost:
"ldap.example.com",
+ common.LdapPort: "389",
+ common.LdapBaseDN:
"dc=example,dc=com",
+ common.LdapFilter:
"(&(objectClass=user)(sAMAccountName=%s))",
+ common.LdapGroupAttr: "memberOf",
+ common.LdapReturnAttr: "memberOf",
+ common.LdapBindUser:
"cn=admin,dc=example,dc=com",
+ common.LdapBindPassword: "password",
+ }
+
+ for key, value := range requiredFields {
+ err :=
os.WriteFile(filepath.Join(tmpDir, key), []byte(value), 0600)
+ if err != nil {
+ t.Fatalf("failed to write file
%s: %v", key, err)
+ }
+ }
+
+ origLdapMountPath := common.LdapMountPath
+ common.LdapMountPath = tmpDir
+ return tmpDir, func() { common.LdapMountPath =
origLdapMountPath }
+ },
+ expectedResult: true,
+ validateFunc: func(_ *testing.T, _ *LdapConfig) {
+ // No specific validation needed - we're
testing the return value
+ },
+ },
+ }
+}
+
+func TestUserGroupCacheLdap(t *testing.T) {
+ tests := []struct {
+ name string
+ validateFunc func(t *testing.T, cache *UserGroupCache)
+ }{
+ {
+ name: "Cache initialization",
+ validateFunc: func(t *testing.T, cache *UserGroupCache)
{
+ assert.Assert(t, cache != nil)
+ assert.Assert(t, cache.ugs != nil)
+ assert.Assert(t, cache.lookup != nil)
+ assert.Assert(t, cache.lookupGroupID != nil)
+ assert.Assert(t, cache.groupIds != nil)
+ },
+ },
+ {
+ name: "Cache interval",
+ validateFunc: func(t *testing.T, cache *UserGroupCache)
{
+ interval := cache.interval
+ expectedInterval := cleanerInterval *
time.Second // 60 seconds
+ assert.Equal(t, expectedInterval, interval,
"LDAP resolver interval should be 60 seconds")
+ },
+ },
+ }
+
+ for _, tt := range tests {
+ t.Run(tt.name, func(t *testing.T) {
+ // Get the LDAP user group cache
+ cache := GetUserGroupCacheLdap(&ConfigReaderMock{},
newMockLdapAccess(nil, nil))
+ // Run the validation function
+ tt.validateFunc(t, cache)
+ })
+ }
+}
+
+func TestMockLdapSearchResult(t *testing.T) {
+ // Test valid users
+ testCases := []struct {
+ username string
+ expectedCount int
+ expectError bool
+ }{
+ {"testuser1", 1, false},
+ {"testuser", 1, false},
+ {"testuser2", 2, false},
+ {"testuser3", 4, false},
+ {"testuser4", 2, false},
+ {"unknown", 0, true},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.username, func(t *testing.T) {
+ result, err := mockLdapSearchResult(tc.username)
+
+ if tc.expectError {
+ assert.Assert(t, err != nil, "Expected error
for user %s but got none", tc.username)
+ assert.Assert(t, result == nil, "Expected nil
result for user %s but got %v", tc.username, result)
+ assert.ErrorContains(t, err, "ldap lookup
failed for user: "+tc.username)
+ } else {
+ assert.NilError(t, err, "Unexpected error for
user %s: %v", tc.username, err)
+ assert.Assert(t, result != nil, "Expected
non-nil result for user %s", tc.username)
+ assert.Assert(t, len(result.Entries) > 0,
"Expected entries for user %s", tc.username)
+ assert.Assert(t,
len(result.Entries[0].Attributes) > 0, "Expected attributes for user %s",
tc.username)
+
+ memberOfAttr := result.Entries[0].Attributes[0]
+ assert.Equal(t, "memberOf", memberOfAttr.Name,
"Expected 'memberOf' attribute for user %s", tc.username)
+ assert.Equal(t, tc.expectedCount,
len(memberOfAttr.Values),
+ "Expected %d group values for user %s
but got %d",
+ tc.expectedCount, tc.username,
len(memberOfAttr.Values))
+ }
+ })
+ }
+}
+
+func TestLdapAccessImpl(t *testing.T) {
+ // Create a mock LDAP access implementation
+ mockAccess := &LdapAccessMock{
+ DialURLFunc: func(url string, options ...ldap.DialOpt)
(*ldap.Conn, error) {
+ return &ldap.Conn{}, nil
+ },
+ BindFunc: func(conn *ldap.Conn, username, password string)
error {
+ return nil
+ },
+ SearchFunc: func(conn *ldap.Conn, searchRequest
*ldap.SearchRequest) (*ldap.SearchResult, error) {
+ return &ldap.SearchResult{}, nil
+ },
+ CloseFunc: func(conn *ldap.Conn) {},
+ }
+
+ assert.Assert(t, mockAccess != nil)
+ conn, err := mockAccess.DialURL("testurl")
+ assert.NilError(t, err)
+ assert.Assert(t, conn != nil)
+ assert.NilError(t, mockAccess.Bind(&ldap.Conn{}, "user", "pass"))
+ result, err := mockAccess.Search(&ldap.Conn{}, &ldap.SearchRequest{})
+ assert.NilError(t, err)
+ assert.Assert(t, result != nil)
+ mockAccess.Close(&ldap.Conn{})
+}
+
+// TestLdapAccessImplMethods tests the ldapAccessImpl methods
+func TestLdapAccessImplMethods(t *testing.T) {
+ // Create a real implementation
+ impl := &ldapAccessImpl{}
+
+ // We can't actually connect to an LDAP server in unit tests,
+ // but we can verify the methods don't panic when called with nil
+
+ // Test DialURL - should return error with invalid URL
+ conn, err := impl.DialURL("invalid://url")
+ assert.Assert(t, err != nil)
+ assert.Assert(t, conn == nil)
+
+ // Other methods would panic if called with nil, so we can't test them
directly
+ // In a real scenario, we'd use a mock LDAP server or dependency
injection
+}
diff --git a/pkg/common/security/usergroup_no_resolver_test.go
b/pkg/common/security/usergroup_no_resolver_test.go
new file mode 100644
index 00000000..1e78eb07
--- /dev/null
+++ b/pkg/common/security/usergroup_no_resolver_test.go
@@ -0,0 +1,100 @@
+/*
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+*/
+
+package security
+
+import (
+ "os/user"
+ "testing"
+
+ "gotest.tools/v3/assert"
+)
+
+func TestNoLookupUser(t *testing.T) {
+ // Test with various usernames
+ testCases := []struct {
+ name string
+ username string
+ }{
+ {"Empty username", ""},
+ {"Standard username", "testuser"},
+ {"Username with special chars", "test-user_123"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ u, err := noLookupUser(tc.username)
+
+ // Should never return error
+ assert.NilError(t, err)
+
+ // Verify user properties
+ assert.Equal(t, tc.username, u.Username)
+ assert.Equal(t, "-1", u.Uid)
+ assert.Equal(t, tc.username, u.Gid)
+ })
+ }
+}
+
+func TestNoLookupGroupID(t *testing.T) {
+ // Test with various group IDs
+ testCases := []struct {
+ name string
+ gid string
+ }{
+ {"Empty GID", ""},
+ {"Numeric GID", "1000"},
+ {"String GID", "users"},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ g, err := noLookupGroupID(tc.gid)
+
+ // Should never return error
+ assert.NilError(t, err)
+
+ // Verify group properties
+ assert.Equal(t, tc.gid, g.Gid)
+ assert.Equal(t, tc.gid, g.Name)
+ })
+ }
+}
+
+func TestNoLookupGroupIds(t *testing.T) {
+ // Test with various users
+ testCases := []struct {
+ name string
+ user *user.User
+ }{
+ {"Standard user", &user.User{Username: "testuser", Uid: "1000",
Gid: "1000"}},
+ {"Empty user", &user.User{}},
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ groups, err := noLookupGroupIds(tc.user)
+
+ // Should never return error
+ assert.NilError(t, err)
+
+ // Should always return empty slice
+ assert.Equal(t, 0, len(groups))
+ })
+ }
+}
diff --git a/pkg/common/constants.go
b/pkg/common/security/usergroup_os_resolver_test.go
similarity index 55%
copy from pkg/common/constants.go
copy to pkg/common/security/usergroup_os_resolver_test.go
index 7eae1c79..acd7d5ab 100644
--- a/pkg/common/constants.go
+++ b/pkg/common/security/usergroup_os_resolver_test.go
@@ -16,17 +16,29 @@
limitations under the License.
*/
-package common
-
-const (
- Empty = ""
-
- Wildcard = "*"
- Separator = ","
- Space = " "
- AnonymousUser = "nobody"
- AnonymousGroup = "nogroup"
- RecoveryQueue = "@recovery@"
- RecoveryQueueFull = "root." + RecoveryQueue
- DefaultPlacementQueue = "root.default"
+package security
+
+import (
+ "os/user"
+ "testing"
)
+
+func TestWrappedGroupIds(t *testing.T) {
+ // Create a mock user
+ // Note: This test will behave differently depending on the system
+ // We'll just verify it doesn't panic and returns the expected type
+ u := &user.User{
+ Username: "testuser",
+ Uid: "1000",
+ Gid: "1000",
+ }
+
+ // Call the function - we can't predict the exact result
+ groups, err := wrappedGroupIds(u)
+
+ // Log the result for informational purposes
+ t.Logf("Groups: %v, Error: %v", groups, err)
+
+ // We can only verify the function doesn't panic
+ // The actual result depends on the OS and user configuration
+}
diff --git a/pkg/common/security/usergroup_test.go
b/pkg/common/security/usergroup_test.go
index 6c0d9f45..2a96fab8 100644
--- a/pkg/common/security/usergroup_test.go
+++ b/pkg/common/security/usergroup_test.go
@@ -28,6 +28,7 @@ import (
"gotest.tools/v3/assert"
"github.com/apache/yunikorn-core/pkg/common"
+ "github.com/apache/yunikorn-core/pkg/common/configs"
"github.com/apache/yunikorn-scheduler-interface/lib/go/si"
)
@@ -50,308 +51,494 @@ func (c *UserGroupCache) getUGmap() map[string]*UserGroup
{
return c.ugs
}
-func TestGetUserGroupCache(t *testing.T) {
- // get the cache with the test resolver set
- testCache := GetUserGroupCache("test")
- assert.Assert(t, testCache != nil, "Cache create failed")
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- testCache.Stop()
- assert.Assert(t, instance == nil, "instance should be nil")
- assert.Assert(t, stopped.Load())
-
- // get the cache with the os resolver set
- testCache = GetUserGroupCache("os")
- assert.Assert(t, testCache != nil, "Cache create failed")
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- testCache.Stop()
- assert.Assert(t, instance == nil, "instance should be nil")
- assert.Assert(t, stopped.Load())
-
- // get the cache with the default resolver set
- testCache = GetUserGroupCache("unknown")
- assert.Assert(t, testCache != nil, "Cache create failed")
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- testCache.Stop()
- assert.Assert(t, instance == nil, "instance should be nil")
- assert.Assert(t, stopped.Load())
-
- // test for re stop again
- testCache.Stop()
- assert.Assert(t, instance == nil, "instance should be nil")
- assert.Assert(t, stopped.Load())
+// UserGroupResolver Config for the test
+var testResolver = configs.UserGroupResolver{
+ Type: "test",
}
-func TestGetUserGroup(t *testing.T) {
- testCache := GetUserGroupCache("test")
- testCache.resetCache()
- // test cache should be empty now
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
- ugi := &si.UserGroupInformation{
- User: "testuser1",
- Groups: nil,
- }
- ug, err := testCache.GetUserGroup(ugi.User)
- assert.NilError(t, err, "Lookup should not have failed: testuser1")
- if ug.failed {
- t.Errorf("lookup failed which should not have: %t", ug.failed)
- }
- if len(testCache.ugs) != 1 {
- t.Errorf("Cache not updated should have 1 entry %d",
len(testCache.ugs))
- }
- // check returned info: primary and secondary groups etc
- if ug.User != ugi.User || len(ug.Groups) != 2 || ug.resolved == 0 ||
ug.failed {
- t.Errorf("User 'testuser1' not resolved correctly: %v", ug)
- }
- testCache.lock.Lock()
- cachedUG := testCache.ugs[ugi.User]
- if ug.resolved != cachedUG.resolved {
- t.Errorf("User 'testuser1' not cached correctly resolution time
differs: %d got %d", ug.resolved, cachedUG.resolved)
- }
- // click over the clock: if we do not get the cached version the new
time will differ from the cache update
- cachedUG.resolved -= 5
- testCache.lock.Unlock()
-
- ug, err = testCache.GetUserGroup(ugi.User)
- if err != nil || ug.resolved != cachedUG.resolved {
- t.Errorf("User 'testuser1' not returned from Cache, resolution
time differs: %d got %d (err = %v)", ug.resolved, cachedUG.resolved, err)
- }
+// UserGroupResolver Config for the os resolver
+var osResolver = configs.UserGroupResolver{
+ Type: "os",
}
-func TestBrokenUserGroup(t *testing.T) {
- testCache := GetUserGroupCache("test")
- testCache.resetCache()
- // test cache should be empty now
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- ug, err := testCache.GetUserGroup("testuser2")
- if err != nil {
- t.Error("Lookup should not have failed: testuser2")
- }
+// UserGroupResolver Config for the unknown resolver
+var unknownResolver = configs.UserGroupResolver{
+ Type: "unknown",
+}
- assert.Equal(t, 1, testCache.getUGsize(), "Cache not updated should
have 1 entry %d", testCache.getUGmap())
- // check returned info: 3 groups etc
- if ug.User != "testuser2" || len(ug.Groups) != 3 || ug.resolved == 0 ||
ug.failed {
- t.Errorf("User 'testuser2' not resolved correctly: %v", ug)
- }
- // first group should have failed resolution: just the ID expected
- if ug.Groups[0] != "100" {
- t.Errorf("User 'testuser2' primary group resolved while it
should not: %v", ug)
- }
+// UserGroupResolver Config for the LDAP resolver
+var ldapResolver = configs.UserGroupResolver{
+ Type: "ldap",
+}
- ug, err = testCache.GetUserGroup("testuser3")
- if err != nil {
- t.Error("Lookup should not have failed: testuser3")
+func TestGetUserGroupCache(t *testing.T) {
+ testCases := []struct {
+ name string
+ resolver configs.UserGroupResolver
+ }{
+ {
+ name: "TestResolver",
+ resolver: testResolver,
+ },
+ {
+ name: "OsResolver",
+ resolver: osResolver,
+ },
+ {
+ name: "UnknownResolver",
+ resolver: unknownResolver,
+ },
+ {
+ name: "LdapResolver",
+ resolver: ldapResolver,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ // Get the cache with the resolver set
+ testCache := GetUserGroupCache(tc.resolver,
&ConfigReaderMock{}, &LdapAccessMock{})
+ assert.Assert(t, testCache != nil, "Cache create
failed")
+ assert.Equal(t, 0, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+
+ testCache.Stop()
+ assert.Assert(t, instance == nil, "instance should be
nil")
+ assert.Assert(t, stopped.Load())
+
+ // Test for re-stop
+ testCache.Stop()
+ assert.Assert(t, instance == nil, "instance should be
nil")
+ assert.Assert(t, stopped.Load())
+ })
}
-
- assert.Equal(t, 2, testCache.getUGsize(), "Cache not updated should
have 2 entries %d", len(testCache.ugs))
- assert.Equal(t, 4, testCache.getUGGroupSize("testuser3"), "User
'testuser3' not resolved correctly: duplicate primary group not filtered %v",
ug)
-
- ug, err = testCache.GetUserGroup("unknown")
- assert.ErrorContains(t, err, "lookup failed for user: unknown")
-
- ug, err = testCache.GetUserGroup("testuser4")
- assert.NilError(t, err)
-
- ug, err = testCache.GetUserGroup("testuser5")
- assert.ErrorContains(t, err, "lookup failed for user: testuser5")
-
- ug, err = testCache.GetUserGroup("invalid-gid-user")
- assert.ErrorContains(t, err, "lookup failed for user: invalid-gid-user")
- exceptedGroup := []string{"1_001"}
- assert.Assert(t, reflect.DeepEqual(ug.Groups, exceptedGroup),
fmt.Errorf("group should be: %v, but got: %v", exceptedGroup, ug.Groups))
}
-func TestGetUserGroupFail(t *testing.T) {
- testCache := GetUserGroupCache("test")
- testCache.resetCache()
- // test cache should be empty now
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- // resolve an empty user
- ug, err := testCache.GetUserGroup("")
- if err == nil {
- t.Error("Lookup should have failed: empty user")
- }
- // ug is empty everything should be nil..
- if ug.User != "" || len(ug.Groups) != 0 || ug.resolved != 0 ||
ug.failed {
- t.Errorf("UserGroup is not empty: %v", ug)
- }
+// Tests for the LDAP resolver using the mock implementation
- // resolve a non existing user
- ugi := &si.UserGroupInformation{
- User: "unknown",
- Groups: nil,
- }
- ug, err = testCache.GetUserGroup(ugi.User)
- if err == nil {
- t.Error("Lookup should have failed: unknown user")
- }
- // ug is partially filled and failed flag is set
- if ug.User != ugi.User || len(ug.Groups) != 0 || !ug.failed {
- t.Errorf("UserGroup is not empty: %v", ug)
+func TestGetUserGroup(t *testing.T) {
+ testCases := []struct {
+ name string
+ resolver configs.UserGroupResolver
+ }{
+ {
+ name: "TestResolver",
+ resolver: testResolver,
+ },
+ {
+ name: "OsResolver",
+ resolver: osResolver,
+ },
+ {
+ name: "UnknownResolver",
+ resolver: unknownResolver,
+ },
+ {
+ name: "LdapResolver",
+ resolver: ldapResolver,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ testCache := GetUserGroupCache(tc.resolver,
&ConfigReaderMock{}, &LdapAccessMock{})
+ testCache.resetCache()
+ // test cache should be empty now
+ assert.Equal(t, 0, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+ ugi := &si.UserGroupInformation{
+ User: "testuser1",
+ Groups: nil,
+ }
+ ug, err := testCache.GetUserGroup(ugi.User)
+ assert.NilError(t, err, "Lookup should not have failed:
testuser1")
+ if ug.failed {
+ t.Errorf("lookup failed which should not have:
%t", ug.failed)
+ }
+ if len(testCache.ugs) != 1 {
+ t.Errorf("Cache not updated should have 1 entry
%d", len(testCache.ugs))
+ }
+ // check returned info: primary and secondary groups etc
+ if ug.User != ugi.User || len(ug.Groups) != 2 ||
ug.resolved == 0 || ug.failed {
+ t.Errorf("User 'testuser1' not resolved
correctly: %v", ug)
+ }
+ testCache.lock.Lock()
+ cachedUG := testCache.ugs[ugi.User]
+ if ug.resolved != cachedUG.resolved {
+ t.Errorf("User 'testuser1' not cached correctly
resolution time differs: %d got %d", ug.resolved, cachedUG.resolved)
+ }
+ // click over the clock: if we do not get the cached
version the new time will differ from the cache update
+ cachedUG.resolved -= 5
+ testCache.lock.Unlock()
+
+ ug, err = testCache.GetUserGroup(ugi.User)
+ if err != nil || ug.resolved != cachedUG.resolved {
+ t.Errorf("User 'testuser1' not returned from
Cache, resolution time differs: %d got %d (err = %v)", ug.resolved,
cachedUG.resolved, err)
+ }
+ })
}
+}
- ug, err = testCache.GetUserGroup(ugi.User)
- if err == nil {
- t.Error("Lookup should have failed: unknown user")
- }
- // ug is partially filled and failed flag is set: error message should
show that the cache was returned
- if err != nil && !strings.Contains(err.Error(), "cached data returned")
{
- t.Errorf("UserGroup not returned from Cache: %v, error: %v",
ug, err)
+func TestBrokenUserGroup(t *testing.T) {
+ testCases := []struct {
+ name string
+ resolver configs.UserGroupResolver
+ }{
+ {
+ name: "TestResolver",
+ resolver: testResolver,
+ },
+ {
+ name: "OsResolver",
+ resolver: osResolver,
+ },
+ {
+ name: "UnknownResolver",
+ resolver: unknownResolver,
+ },
+ {
+ name: "LdapResolver",
+ resolver: ldapResolver,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ testCache := GetUserGroupCache(tc.resolver,
&ConfigReaderMock{}, &LdapAccessMock{})
+ testCache.resetCache()
+ // test cache should be empty now
+ assert.Equal(t, 0, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+
+ ug, err := testCache.GetUserGroup("testuser2")
+ if err != nil {
+ t.Error("Lookup should not have failed:
testuser2")
+ }
+
+ assert.Equal(t, 1, testCache.getUGsize(), "Cache not
updated should have 1 entry %d", testCache.getUGmap())
+ // check returned info: 3 groups etc
+ if ug.User != "testuser2" || len(ug.Groups) != 3 ||
ug.resolved == 0 || ug.failed {
+ t.Errorf("User 'testuser2' not resolved
correctly: %v", ug)
+ }
+ // first group should have failed resolution: just the
ID expected
+ if ug.Groups[0] != "100" {
+ t.Errorf("User 'testuser2' primary group
resolved while it should not: %v", ug)
+ }
+
+ ug, err = testCache.GetUserGroup("testuser3")
+ if err != nil {
+ t.Error("Lookup should not have failed:
testuser3")
+ }
+
+ assert.Equal(t, 2, testCache.getUGsize(), "Cache not
updated should have 2 entries %d", len(testCache.ugs))
+ assert.Equal(t, 4,
testCache.getUGGroupSize("testuser3"), "User 'testuser3' not resolved
correctly: duplicate primary group not filtered %v", ug)
+
+ ug, err = testCache.GetUserGroup("unknown")
+ assert.ErrorContains(t, err, "lookup failed for user:
unknown")
+
+ ug, err = testCache.GetUserGroup("testuser4")
+ assert.NilError(t, err)
+
+ ug, err = testCache.GetUserGroup("testuser5")
+ assert.ErrorContains(t, err, "lookup failed for user:
testuser5")
+
+ ug, err = testCache.GetUserGroup("invalid-gid-user")
+ assert.ErrorContains(t, err, "lookup failed for user:
invalid-gid-user")
+ exceptedGroup := []string{"1_001"}
+ assert.Assert(t, reflect.DeepEqual(ug.Groups,
exceptedGroup), fmt.Errorf("group should be: %v, but got: %v", exceptedGroup,
ug.Groups))
+ })
}
}
-func TestCacheCleanUp(t *testing.T) {
- testCache := GetUserGroupCache("test")
- testCache.resetCache()
- // test cache should be empty now
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- // resolve an existing user
- _, err := testCache.GetUserGroup("testuser1")
- if err != nil {
- t.Error("Lookup should not have failed: testuser1 user")
- }
- _, err = testCache.GetUserGroup("testuser2")
- if err != nil {
- t.Error("Lookup should not have failed: testuser2 user")
+func TestGetUserGroupFail(t *testing.T) {
+ testCases := []struct {
+ name string
+ resolver configs.UserGroupResolver
+ }{
+ {
+ name: "TestResolver",
+ resolver: testResolver,
+ },
+ {
+ name: "OsResolver",
+ resolver: osResolver,
+ },
+ {
+ name: "UnknownResolver",
+ resolver: unknownResolver,
+ },
+ {
+ name: "LdapResolver",
+ resolver: ldapResolver,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ testCache := GetUserGroupCache(tc.resolver,
&ConfigReaderMock{}, &LdapAccessMock{})
+ testCache.resetCache()
+ // test cache should be empty now
+ assert.Equal(t, 0, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+
+ // resolve an empty user
+ ug, err := testCache.GetUserGroup("")
+ if err == nil {
+ t.Error("Lookup should have failed: empty user")
+ }
+ // ug is empty everything should be nil..
+ if ug.User != "" || len(ug.Groups) != 0 || ug.resolved
!= 0 || ug.failed {
+ t.Errorf("UserGroup is not empty: %v", ug)
+ }
+
+ // resolve a non existing user
+ ugi := &si.UserGroupInformation{
+ User: "unknown",
+ Groups: nil,
+ }
+ ug, err = testCache.GetUserGroup(ugi.User)
+ if err == nil {
+ t.Error("Lookup should have failed: unknown
user")
+ }
+ // ug is partially filled and failed flag is set
+ if ug.User != ugi.User || len(ug.Groups) != 0 ||
!ug.failed {
+ t.Errorf("UserGroup is not empty: %v", ug)
+ }
+
+ ug, err = testCache.GetUserGroup(ugi.User)
+ if err == nil {
+ t.Error("Lookup should have failed: unknown
user")
+ }
+ // ug is partially filled and failed flag is set: error
message should show that the cache was returned
+ if err != nil && !strings.Contains(err.Error(), "cached
data returned") {
+ t.Errorf("UserGroup not returned from Cache:
%v, error: %v", ug, err)
+ }
+ })
}
+}
- testCache.lock.Lock()
- ug := testCache.ugs["testuser1"]
- if ug.failed {
- t.Error("User 'testuser1' not resolved as a success")
- }
- // expire the successful lookup
- ug.resolved -= 2 * poscache
- testCache.lock.Unlock()
-
- // resolve a non existing user
- _, err = testCache.GetUserGroup("unknown")
- if err == nil {
- t.Error("Lookup should have failed: unknown user")
- }
- testCache.lock.Lock()
- ug = testCache.ugs["unknown"]
- if !ug.failed {
- t.Error("User 'unknown' not resolved as a failure")
+func TestCacheCleanUp(t *testing.T) {
+ testCases := []struct {
+ name string
+ resolver configs.UserGroupResolver
+ }{
+ {
+ name: "TestResolver",
+ resolver: testResolver,
+ },
+ {
+ name: "OsResolver",
+ resolver: osResolver,
+ },
+ {
+ name: "UnknownResolver",
+ resolver: unknownResolver,
+ },
+ {
+ name: "LdapResolver",
+ resolver: ldapResolver,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ testCache := GetUserGroupCache(tc.resolver,
&ConfigReaderMock{}, &LdapAccessMock{})
+ testCache.resetCache()
+ // test cache should be empty now
+ assert.Equal(t, 0, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+
+ // resolve an existing user
+ _, err := testCache.GetUserGroup("testuser1")
+ if err != nil {
+ t.Error("Lookup should not have failed:
testuser1 user")
+ }
+ _, err = testCache.GetUserGroup("testuser2")
+ if err != nil {
+ t.Error("Lookup should not have failed:
testuser2 user")
+ }
+
+ testCache.lock.Lock()
+ ug := testCache.ugs["testuser1"]
+ if ug.failed {
+ t.Error("User 'testuser1' not resolved as a
success")
+ }
+ // expire the successful lookup
+ ug.resolved -= 2 * poscache
+ testCache.lock.Unlock()
+
+ // resolve a non existing user
+ _, err = testCache.GetUserGroup("unknown")
+ if err == nil {
+ t.Error("Lookup should have failed: unknown
user")
+ }
+ testCache.lock.Lock()
+ ug = testCache.ugs["unknown"]
+ if !ug.failed {
+ t.Error("User 'unknown' not resolved as a
failure")
+ }
+ // expire the failed lookup
+ ug.resolved -= 2 * negcache
+ testCache.lock.Unlock()
+
+ testCache.cleanUpCache()
+ assert.Equal(t, 1, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+ })
}
- // expire the failed lookup
- ug.resolved -= 2 * negcache
- testCache.lock.Unlock()
-
- testCache.cleanUpCache()
- assert.Equal(t, 1, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
}
func TestIntervalCacheCleanUp(t *testing.T) {
- testCache := GetUserGroupCache("test")
- testCache.resetCache()
- // test cache should be empty now
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- // resolve an existing user
- user1ug, err := testCache.GetUserGroup("testuser1")
- assert.NilError(t, err, "Lookup should not have failed: testuser1 user")
- assert.Assert(t, !user1ug.failed, "User 'testuser1' not resolved as a
success")
-
- _, err = testCache.GetUserGroup("testuser2")
- assert.NilError(t, err, "Lookup should not have failed: testuser1 user")
-
- // expire the successful lookup
- testCache.lock.Lock()
- ug := testCache.ugs["testuser1"]
- ug.resolved -= 2 * poscache
-
- testCache.lock.Unlock()
- // resolve a non existing user
- _, err = testCache.GetUserGroup("unknown")
- assert.Assert(t, err != nil, "Lookup should have failed: unknown user")
- testCache.lock.Lock()
- ug = testCache.ugs["unknown"]
- assert.Assert(t, ug.failed, "User 'unknown' not resolved as a failure")
-
- // expire the failed lookup
- ug.resolved -= 2 * negcache
- testCache.lock.Unlock()
-
- // sleep to wait for interval, it will trigger cleanUpCache
- time.Sleep(testCache.interval + time.Second)
- assert.Equal(t, 1, testCache.getUGsize(), "Cache not cleaned up : %v",
testCache.getUGmap())
+ testCases := []struct {
+ name string
+ resolver configs.UserGroupResolver
+ }{
+ {
+ name: "TestResolver",
+ resolver: testResolver,
+ },
+ {
+ name: "OsResolver",
+ resolver: osResolver,
+ },
+ {
+ name: "UnknownResolver",
+ resolver: unknownResolver,
+ },
+ {
+ name: "LdapResolver",
+ resolver: ldapResolver,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ testCache := GetUserGroupCache(tc.resolver,
&ConfigReaderMock{}, &LdapAccessMock{})
+ testCache.resetCache()
+ // test cache should be empty now
+ assert.Equal(t, 0, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+
+ // resolve an existing user
+ user1ug, err := testCache.GetUserGroup("testuser1")
+ assert.NilError(t, err, "Lookup should not have failed:
testuser1 user")
+ assert.Assert(t, !user1ug.failed, "User 'testuser1' not
resolved as a success")
+
+ _, err = testCache.GetUserGroup("testuser2")
+ assert.NilError(t, err, "Lookup should not have failed:
testuser1 user")
+
+ // expire the successful lookup
+ testCache.lock.Lock()
+ ug := testCache.ugs["testuser1"]
+ ug.resolved -= 2 * poscache
+
+ testCache.lock.Unlock()
+ // resolve a non existing user
+ _, err = testCache.GetUserGroup("unknown")
+ assert.Assert(t, err != nil, "Lookup should have
failed: unknown user")
+ testCache.lock.Lock()
+ ug = testCache.ugs["unknown"]
+ assert.Assert(t, ug.failed, "User 'unknown' not
resolved as a failure")
+
+ // expire the failed lookup
+ ug.resolved -= 2 * negcache
+ testCache.lock.Unlock()
+
+ // sleep to wait for interval, it will trigger
cleanUpCache
+ time.Sleep(testCache.interval + time.Second)
+ assert.Equal(t, 1, testCache.getUGsize(), "Cache not
cleaned up : %v", testCache.getUGmap())
+ })
+ }
}
func TestConvertUGI(t *testing.T) {
- testCache := GetUserGroupCache("test")
- testCache.resetCache()
- // test cache should be empty now
- assert.Equal(t, 0, testCache.getUGsize(), "Cache is not empty: %v",
testCache.getUGmap())
-
- ugi := &si.UserGroupInformation{
- User: "",
- Groups: nil,
- }
- ug, err := testCache.ConvertUGI(ugi, false)
- if err == nil {
- t.Errorf("empty user convert should have failed and did not:
%v", ug)
- }
- // try known user without groups
- ugi.User = "testuser1"
- ug, err = testCache.ConvertUGI(ugi, false)
- if err != nil {
- t.Errorf("known user, no groups, convert should not have
failed: %v", err)
- }
- if ug.User != "testuser1" || len(ug.Groups) != 2 || ug.resolved == 0 ||
ug.failed {
- t.Errorf("User 'testuser1' not resolved correctly: %v", ug)
- }
- // try unknown user without groups
- ugi.User = "unknown"
- ug, err = testCache.ConvertUGI(ugi, false)
- if err == nil {
- t.Errorf("unknown user, no groups, convert should have failed:
%v", ug)
- }
- // try empty user when forced
- ugi.User = ""
- ug, err = testCache.ConvertUGI(ugi, true)
- if err != nil {
- t.Errorf("empty user but forced, convert should not have
failed: %v", err)
+ testCases := []struct {
+ name string
+ resolver configs.UserGroupResolver
+ }{
+ {
+ name: "TestResolver",
+ resolver: testResolver,
+ },
+ {
+ name: "OsResolver",
+ resolver: osResolver,
+ },
+ {
+ name: "UnknownResolver",
+ resolver: unknownResolver,
+ },
+ {
+ name: "LdapResolver",
+ resolver: ldapResolver,
+ },
+ }
+
+ for _, tc := range testCases {
+ t.Run(tc.name, func(t *testing.T) {
+ testCache := GetUserGroupCache(tc.resolver,
&ConfigReaderMock{}, &LdapAccessMock{})
+ testCache.resetCache()
+ // test cache should be empty now
+ assert.Equal(t, 0, testCache.getUGsize(), "Cache is not
empty: %v", testCache.getUGmap())
+
+ ugi := &si.UserGroupInformation{
+ User: "",
+ Groups: nil,
+ }
+ ug, err := testCache.ConvertUGI(ugi, false)
+ if err == nil {
+ t.Errorf("empty user convert should have failed
and did not: %v", ug)
+ }
+ // try known user without groups
+ ugi.User = "testuser1"
+ ug, err = testCache.ConvertUGI(ugi, false)
+ if err != nil {
+ t.Errorf("known user, no groups, convert should
not have failed: %v", err)
+ }
+ if ug.User != "testuser1" || len(ug.Groups) != 2 ||
ug.resolved == 0 || ug.failed {
+ t.Errorf("User 'testuser1' not resolved
correctly: %v", ug)
+ }
+ // try unknown user without groups
+ ugi.User = "unknown"
+ ug, err = testCache.ConvertUGI(ugi, false)
+ if err == nil {
+ t.Errorf("unknown user, no groups, convert
should have failed: %v", ug)
+ }
+ // try empty user when forced
+ ugi.User = ""
+ ug, err = testCache.ConvertUGI(ugi, true)
+ if err != nil {
+ t.Errorf("empty user but forced, convert should
not have failed: %v", err)
+ }
+ // try unknown user with groups
+ ugi.User = "unknown2"
+ group := "passedin"
+ ugi.Groups = []string{group}
+ ug, err = testCache.ConvertUGI(ugi, false)
+ if err != nil {
+ t.Errorf("unknown user with groups, convert
should not have failed: %v", err)
+ }
+ if ug.User != "unknown2" || len(ug.Groups) != 1 ||
ug.resolved == 0 || ug.failed {
+ t.Fatalf("User 'unknown2' not resolved
correctly: %v", ug)
+ }
+ if ug.Groups[0] != group {
+ t.Errorf("groups not initialised correctly on
convert: expected '%s' got '%s'", group, ug.Groups[0])
+ }
+ // try valid username with groups
+ ugi.User = "validuserABCD1234@://#"
+ ugi.Groups = []string{group}
+ ug, err = testCache.ConvertUGI(ugi, false)
+ if err != nil {
+ t.Errorf("valid username with groups, convert
should not have failed: %v", err)
+ }
+ // try invalid username with groups
+ ugi.User = "invaliduser><+"
+ ugi.Groups = []string{group}
+ ug, err = testCache.ConvertUGI(ugi, false)
+ if err == nil {
+ t.Errorf("invalid username, convert should have
failed: %v", err)
+ }
+
+ // try unknown user with empty group when forced
+ ugi.User = "unknown"
+ ugi.Groups = []string{}
+ ug, err = testCache.ConvertUGI(ugi, true)
+ exceptedGroup := []string{common.AnonymousGroup}
+ assert.Assert(t, reflect.DeepEqual(ug.Groups,
exceptedGroup), "group should be: %v, but got: %v", exceptedGroup, ug.Groups)
+ assert.NilError(t, err, "unknown user, no groups,
convert should not have failed")
+ })
}
- // try unknown user with groups
- ugi.User = "unknown2"
- group := "passedin"
- ugi.Groups = []string{group}
- ug, err = testCache.ConvertUGI(ugi, false)
- if err != nil {
- t.Errorf("unknown user with groups, convert should not have
failed: %v", err)
- }
- if ug.User != "unknown2" || len(ug.Groups) != 1 || ug.resolved == 0 ||
ug.failed {
- t.Fatalf("User 'unknown2' not resolved correctly: %v", ug)
- }
- if ug.Groups[0] != group {
- t.Errorf("groups not initialised correctly on convert: expected
'%s' got '%s'", group, ug.Groups[0])
- }
- // try valid username with groups
- ugi.User = "validuserABCD1234@://#"
- ugi.Groups = []string{group}
- ug, err = testCache.ConvertUGI(ugi, false)
- if err != nil {
- t.Errorf("valid username with groups, convert should not have
failed: %v", err)
- }
- // try invalid username with groups
- ugi.User = "invaliduser><+"
- ugi.Groups = []string{group}
- ug, err = testCache.ConvertUGI(ugi, false)
- if err == nil {
- t.Errorf("invalid username, convert should have failed: %v",
err)
- }
-
- // try unknown user with empty group when forced
- ugi.User = "unknown"
- ugi.Groups = []string{}
- ug, err = testCache.ConvertUGI(ugi, true)
- exceptedGroup := []string{common.AnonymousGroup}
- assert.Assert(t, reflect.DeepEqual(ug.Groups, exceptedGroup), "group
should be: %v, but got: %v", exceptedGroup, ug.Groups)
- assert.NilError(t, err, "unknown user, no groups, convert should not
have failed")
}
diff --git a/pkg/common/security/usergroup_test_resolver.go
b/pkg/common/security/usergroup_test_resolver.go
index c35dd6bb..47058c26 100644
--- a/pkg/common/security/usergroup_test_resolver.go
+++ b/pkg/common/security/usergroup_test_resolver.go
@@ -26,9 +26,12 @@ import (
)
const (
+ Testuser = "testuser"
Testuser1 = "testuser1"
Testuser2 = "testuser2"
Testuser3 = "testuser3"
+ Testuser4 = "testuser4"
+ Testuser5 = "testuser5"
)
// Get the cache with a test resolver
diff --git a/pkg/scheduler/partition.go b/pkg/scheduler/partition.go
index 62e40130..96c58943 100644
--- a/pkg/scheduler/partition.go
+++ b/pkg/scheduler/partition.go
@@ -136,7 +136,7 @@ func (pc *PartitionContext) initialPartitionFromConfig(conf
configs.PartitionCon
// Placing an application will not have a lock on the partition context.
pc.placementManager =
placement.NewPlacementManager(conf.PlacementRules, pc.GetQueue, silence)
// get the user group cache for the partition
- pc.userGroupCache = security.GetUserGroupCache("")
+ pc.userGroupCache = security.GetUserGroupCache(conf.UserGroupResolver,
security.GetConfigReader(), security.GetLdapAccess())
pc.updateNodeSortingPolicy(conf, silence)
pc.updatePreemption(conf)
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]