diff --git a/internal/upstreamldap/upstreamldap_test.go b/internal/upstreamldap/upstreamldap_test.go index c8bdc3959..b4ee6bdfe 100644 --- a/internal/upstreamldap/upstreamldap_test.go +++ b/internal/upstreamldap/upstreamldap_test.go @@ -51,6 +51,8 @@ const ( testUserSearchResultUIDAttributeValue = "some-upstream-uid-value" testGroupSearchResultGroupNameAttributeValue1 = "some-upstream-group-name-value1" testGroupSearchResultGroupNameAttributeValue2 = "some-upstream-group-name-value2" + testUserDNWithSpecialChars = `user DN with * \ special characters ()` + testUserDNWithSpecialCharsEscaped = `user DN with \2a \5c special characters \28\29` expectedGroupSearchPageSize = uint32(250) ) @@ -529,7 +531,7 @@ func TestEndUserAuthentication(t *testing.T) { Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { - DN: `result DN with * \ special characters ()`, + DN: testUserDNWithSpecialChars, Attributes: []*ldap.EntryAttribute{ ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testUserSearchResultUsernameAttributeValue}), ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testUserSearchResultUIDAttributeValue}), @@ -538,16 +540,16 @@ func TestEndUserAuthentication(t *testing.T) { }, }, nil).Times(1) conn.EXPECT().SearchWithPaging(expectedGroupSearch(func(r *ldap.SearchRequest) { - escapedDN := `result DN with \2a \5c special characters \28\29` + escapedDN := testUserDNWithSpecialCharsEscaped r.Filter = fmt.Sprintf("(some-group-filter=%s-and-more-filter=%s)", escapedDN, escapedDN) }), expectedGroupSearchPageSize).Return(exampleGroupSearchResult, nil).Times(1) conn.EXPECT().Close().Times(1) }, bindEndUserMocks: func(conn *mockldapconn.MockConn) { - conn.EXPECT().Bind(`result DN with * \ special characters ()`, testUpstreamPassword).Times(1) + conn.EXPECT().Bind(testUserDNWithSpecialChars, testUpstreamPassword).Times(1) }, wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { - r.DN = `result DN with * \ special characters ()` + r.DN = testUserDNWithSpecialChars }), }, { @@ -563,7 +565,7 @@ func TestEndUserAuthentication(t *testing.T) { Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { - DN: `result DN with * \ special characters ()`, + DN: testUserDNWithSpecialChars, Attributes: []*ldap.EntryAttribute{ ldap.NewEntryAttribute(testUserSearchUsernameAttribute, []string{testUserSearchResultUsernameAttributeValue}), ldap.NewEntryAttribute(testUserSearchUIDAttribute, []string{testUserSearchResultUIDAttributeValue}), @@ -572,15 +574,15 @@ func TestEndUserAuthentication(t *testing.T) { }, }, nil).Times(1) conn.EXPECT().SearchWithPaging(expectedGroupSearch(func(r *ldap.SearchRequest) { - r.Filter = fmt.Sprintf("(member=%s)", `result DN with \2a \5c special characters \28\29`) + r.Filter = fmt.Sprintf("(member=%s)", testUserDNWithSpecialCharsEscaped) }), expectedGroupSearchPageSize).Return(exampleGroupSearchResult, nil).Times(1) conn.EXPECT().Close().Times(1) }, bindEndUserMocks: func(conn *mockldapconn.MockConn) { - conn.EXPECT().Bind(`result DN with * \ special characters ()`, testUpstreamPassword).Times(1) + conn.EXPECT().Bind(testUserDNWithSpecialChars, testUpstreamPassword).Times(1) }, wantAuthResponse: expectedAuthResponse(func(r *authenticators.Response) { - r.DN = `result DN with * \ special characters ()` + r.DN = testUserDNWithSpecialChars }), }, { @@ -1219,28 +1221,41 @@ func TestEndUserAuthentication(t *testing.T) { func TestUpstreamRefresh(t *testing.T) { pwdLastSetAttribute := "pwdLastSet" - expectedUserSearch := &ldap.SearchRequest{ - BaseDN: testUserSearchResultDNValue, - Scope: ldap.ScopeBaseObject, - DerefAliases: ldap.NeverDerefAliases, - SizeLimit: 2, - TimeLimit: 90, - TypesOnly: false, - Filter: "(objectClass=*)", - Attributes: []string{testUserSearchUsernameAttribute, testUserSearchUIDAttribute, pwdLastSetAttribute}, - Controls: nil, // don't need paging because we set the SizeLimit so small + + expectedUserSearch := func(editFunc func(r *ldap.SearchRequest)) *ldap.SearchRequest { + request := &ldap.SearchRequest{ + BaseDN: testUserSearchResultDNValue, + Scope: ldap.ScopeBaseObject, + DerefAliases: ldap.NeverDerefAliases, + SizeLimit: 2, + TimeLimit: 90, + TypesOnly: false, + Filter: "(objectClass=*)", + Attributes: []string{testUserSearchUsernameAttribute, testUserSearchUIDAttribute, pwdLastSetAttribute}, + Controls: nil, // don't need paging because we set the SizeLimit so small + } + if editFunc != nil { + editFunc(request) + } + return request } - expectedGroupSearch := &ldap.SearchRequest{ - BaseDN: testGroupSearchBase, - Scope: ldap.ScopeWholeSubtree, - DerefAliases: ldap.NeverDerefAliases, - SizeLimit: 0, // unlimited size because we will search with paging - TimeLimit: 90, - TypesOnly: false, - Filter: testGroupSearchFilterInterpolated, - Attributes: []string{testGroupSearchGroupNameAttribute}, - Controls: nil, // nil because ldap.SearchWithPaging() will set the appropriate controls for us + expectedGroupSearch := func(editFunc func(r *ldap.SearchRequest)) *ldap.SearchRequest { + request := &ldap.SearchRequest{ + BaseDN: testGroupSearchBase, + Scope: ldap.ScopeWholeSubtree, + DerefAliases: ldap.NeverDerefAliases, + SizeLimit: 0, // unlimited size because we will search with paging + TimeLimit: 90, + TypesOnly: false, + Filter: testGroupSearchFilterInterpolated, + Attributes: []string{testGroupSearchGroupNameAttribute}, + Controls: nil, // nil because ldap.SearchWithPaging() will set the appropriate controls for us + } + if editFunc != nil { + editFunc(request) + } + return request } happyPathUserSearchResult := &ldap.SearchResult{ @@ -1266,116 +1281,170 @@ func TestUpstreamRefresh(t *testing.T) { }, } - providerConfig := &ProviderConfig{ - Name: "some-provider-name", - Host: testHost, - CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test - ConnectionProtocol: TLS, - BindUsername: testBindUsername, - BindPassword: testBindPassword, - UserSearch: UserSearchConfig{ - Base: testUserSearchBase, - UIDAttribute: testUserSearchUIDAttribute, - UsernameAttribute: testUserSearchUsernameAttribute, - }, - RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{ - pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute), + happyPathGroupSearchResult := &ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testGroupSearchResultDNValue1, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue1}), + }, + }, + { + DN: testGroupSearchResultDNValue2, + Attributes: []*ldap.EntryAttribute{ + ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue2}), + }, + }, }, + Referrals: []string{}, // note that we are not following referrals at this time + Controls: []ldap.Control{}, + } + + providerConfig := func(editFunc func(p *ProviderConfig)) *ProviderConfig { + config := &ProviderConfig{ + Name: "some-provider-name", + Host: testHost, + CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test + ConnectionProtocol: TLS, + BindUsername: testBindUsername, + BindPassword: testBindPassword, + UserSearch: UserSearchConfig{ + Base: testUserSearchBase, + UIDAttribute: testUserSearchUIDAttribute, + UsernameAttribute: testUserSearchUsernameAttribute, + }, + GroupSearch: GroupSearchConfig{ + Base: testGroupSearchBase, + Filter: testGroupSearchFilter, + GroupNameAttribute: testGroupSearchGroupNameAttribute, + }, + RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{ + pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute), + }, + } + if editFunc != nil { + editFunc(config) + } + return config } tests := []struct { name string providerConfig *ProviderConfig setupMocks func(conn *mockldapconn.MockConn) + refreshUserDN string dialError error wantErr string wantGroups []string }{ { - name: "happy path where searching the dn returns a single entry", - providerConfig: providerConfig, + name: "happy path without group search where searching the dn returns a single entry", + providerConfig: providerConfig(func(p *ProviderConfig) { + p.GroupSearch = GroupSearchConfig{} + }), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) + conn.EXPECT().Search(expectedUserSearch(nil)).Return(happyPathUserSearchResult, nil).Times(1) conn.EXPECT().Close().Times(1) }, wantGroups: []string{}, }, { - name: "happy path where group search returns groups", - providerConfig: &ProviderConfig{ - Name: "some-provider-name", - Host: testHost, - CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test - ConnectionProtocol: TLS, - BindUsername: testBindUsername, - BindPassword: testBindPassword, - UserSearch: UserSearchConfig{ - Base: testUserSearchBase, - UIDAttribute: testUserSearchUIDAttribute, - UsernameAttribute: testUserSearchUsernameAttribute, - }, - GroupSearch: GroupSearchConfig{ - Base: testGroupSearchBase, - Filter: testGroupSearchFilter, - GroupNameAttribute: testGroupSearchGroupNameAttribute, - }, - RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{ - pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute), - }, + name: "happy path where group search returns groups", + providerConfig: providerConfig(nil), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch(nil)).Return(happyPathUserSearchResult, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(nil), expectedGroupSearchPageSize).Return(happyPathGroupSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) }, + wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2}, + }, + { + name: "happy path when the user DN has special LDAP search filter characters then they must be properly escaped in the custom group search filter", + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) - conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{ - Entries: []*ldap.Entry{ - { - DN: testGroupSearchResultDNValue1, - Attributes: []*ldap.EntryAttribute{ - ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue1}), + conn.EXPECT().Search(expectedUserSearch(func(r *ldap.SearchRequest) { + r.BaseDN = testUserDNWithSpecialChars + })). + Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testUserDNWithSpecialChars, + Attributes: []*ldap.EntryAttribute{ + { + Name: testUserSearchUsernameAttribute, + Values: []string{testUserSearchResultUsernameAttributeValue}, + }, + { + Name: testUserSearchUIDAttribute, + ByteValues: [][]byte{[]byte(testUserSearchResultUIDAttributeValue)}, + }, + { + Name: pwdLastSetAttribute, + Values: []string{"132801740800000000"}, + ByteValues: [][]byte{[]byte("132801740800000000")}, + }, + }, }, }, - { - DN: testGroupSearchResultDNValue2, - Attributes: []*ldap.EntryAttribute{ - ldap.NewEntryAttribute(testGroupSearchGroupNameAttribute, []string{testGroupSearchResultGroupNameAttributeValue2}), + }, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(func(r *ldap.SearchRequest) { + r.Filter = fmt.Sprintf("(some-group-filter=%s-and-more-filter=%s)", testUserDNWithSpecialCharsEscaped, testUserDNWithSpecialCharsEscaped) + }), expectedGroupSearchPageSize).Return(happyPathGroupSearchResult, nil).Times(1) + conn.EXPECT().Close().Times(1) + }, + refreshUserDN: testUserDNWithSpecialChars, + wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2}, + }, + { + name: "when the user DN has special LDAP search filter characters then they must be properly escaped in the default group search filter", + providerConfig: providerConfig(func(p *ProviderConfig) { + p.GroupSearch.Filter = "" + }), + setupMocks: func(conn *mockldapconn.MockConn) { + conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) + conn.EXPECT().Search(expectedUserSearch(func(r *ldap.SearchRequest) { + r.BaseDN = testUserDNWithSpecialChars + })). + Return(&ldap.SearchResult{ + Entries: []*ldap.Entry{ + { + DN: testUserDNWithSpecialChars, + Attributes: []*ldap.EntryAttribute{ + { + Name: testUserSearchUsernameAttribute, + Values: []string{testUserSearchResultUsernameAttributeValue}, + }, + { + Name: testUserSearchUIDAttribute, + ByteValues: [][]byte{[]byte(testUserSearchResultUIDAttributeValue)}, + }, + { + Name: pwdLastSetAttribute, + Values: []string{"132801740800000000"}, + ByteValues: [][]byte{[]byte("132801740800000000")}, + }, + }, }, }, - }, - Referrals: []string{}, // note that we are not following referrals at this time - Controls: []ldap.Control{}, - }, nil).Times(1) + }, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(func(r *ldap.SearchRequest) { + r.Filter = fmt.Sprintf("(member=%s)", testUserDNWithSpecialCharsEscaped) + }), expectedGroupSearchPageSize).Return(happyPathGroupSearchResult, nil).Times(1) conn.EXPECT().Close().Times(1) }, - wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2}, + refreshUserDN: testUserDNWithSpecialChars, + wantGroups: []string{testGroupSearchResultGroupNameAttributeValue1, testGroupSearchResultGroupNameAttributeValue2}, }, { - name: "happy path where group search returns no groups", - providerConfig: &ProviderConfig{ - Name: "some-provider-name", - Host: testHost, - CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test - ConnectionProtocol: TLS, - BindUsername: testBindUsername, - BindPassword: testBindPassword, - UserSearch: UserSearchConfig{ - Base: testUserSearchBase, - UIDAttribute: testUserSearchUIDAttribute, - UsernameAttribute: testUserSearchUsernameAttribute, - }, - GroupSearch: GroupSearchConfig{ - Base: testGroupSearchBase, - Filter: testGroupSearchFilter, - GroupNameAttribute: testGroupSearchGroupNameAttribute, - }, - RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{ - pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute), - }, - }, + name: "happy path where group search returns no groups", + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) - conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(happyPathUserSearchResult, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(nil), expectedGroupSearchPageSize).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{}, Referrals: []string{}, // note that we are not following referrals at this time Controls: []ldap.Control{}, @@ -1386,44 +1455,25 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "happy path where group search is configured but skipGroupRefresh is set", - providerConfig: &ProviderConfig{ - Name: "some-provider-name", - Host: testHost, - CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test - ConnectionProtocol: TLS, - BindUsername: testBindUsername, - BindPassword: testBindPassword, - UserSearch: UserSearchConfig{ - Base: testUserSearchBase, - UIDAttribute: testUserSearchUIDAttribute, - UsernameAttribute: testUserSearchUsernameAttribute, - }, - GroupSearch: GroupSearchConfig{ - Base: testGroupSearchBase, - Filter: testGroupSearchFilter, - GroupNameAttribute: testGroupSearchGroupNameAttribute, - SkipGroupRefresh: true, - }, - RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{ - pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute), - }, - }, + providerConfig: providerConfig(func(p *ProviderConfig) { + p.GroupSearch.SkipGroupRefresh = true + }), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) // note that group search is not expected + conn.EXPECT().Search(expectedUserSearch(nil)).Return(happyPathUserSearchResult, nil).Times(1) // note that group search is not expected conn.EXPECT().Close().Times(1) }, wantGroups: nil, // do not update groups }, { name: "error where dial fails", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), dialError: errors.New("some dial error"), wantErr: "error dialing host \"ldap.example.com:8443\": some dial error", }, { name: "error binding", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Return(errors.New("some bind error")).Times(1) conn.EXPECT().Close().Times(1) @@ -1432,10 +1482,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result returns no entries", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{}, }, nil).Times(1) conn.EXPECT().Close().Times(1) @@ -1444,20 +1494,20 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "error searching", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(nil, errors.New("some search error")) + conn.EXPECT().Search(expectedUserSearch(nil)).Return(nil, errors.New("some search error")) conn.EXPECT().Close().Times(1) }, wantErr: "error searching for user \"some-upstream-user-dn\": some search error", }, { name: "search result returns more than one entry", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1475,10 +1525,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has wrong uid", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1501,10 +1551,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has wrong username", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1523,10 +1573,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has no dn", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { Attributes: []*ldap.EntryAttribute{ @@ -1548,10 +1598,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has 0 values for username attribute", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1574,10 +1624,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has more than one value for username attribute", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1600,10 +1650,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has 0 values for uid attribute", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1626,10 +1676,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has 2 values for uid attribute", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1652,10 +1702,10 @@ func TestUpstreamRefresh(t *testing.T) { }, { name: "search result has a changed pwdLastSet value", - providerConfig: providerConfig, + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(&ldap.SearchResult{ + conn.EXPECT().Search(expectedUserSearch(nil)).Return(&ldap.SearchResult{ Entries: []*ldap.Entry{ { DN: testUserSearchResultDNValue, @@ -1681,32 +1731,12 @@ func TestUpstreamRefresh(t *testing.T) { wantErr: "validation for attribute \"pwdLastSet\" failed during upstream refresh: value for attribute \"pwdLastSet\" has changed since initial value at login", }, { - name: "group search returns an error", - providerConfig: &ProviderConfig{ - Name: "some-provider-name", - Host: testHost, - CABundle: nil, // this field is only used by the production dialer, which is replaced by a mock for this test - ConnectionProtocol: TLS, - BindUsername: testBindUsername, - BindPassword: testBindPassword, - UserSearch: UserSearchConfig{ - Base: testUserSearchBase, - UIDAttribute: testUserSearchUIDAttribute, - UsernameAttribute: testUserSearchUsernameAttribute, - }, - GroupSearch: GroupSearchConfig{ - Base: testGroupSearchBase, - Filter: testGroupSearchFilter, - GroupNameAttribute: testGroupSearchGroupNameAttribute, - }, - RefreshAttributeChecks: map[string]func(*ldap.Entry, provider.StoredRefreshAttributes) error{ - pwdLastSetAttribute: AttributeUnchangedSinceLogin(pwdLastSetAttribute), - }, - }, + name: "group search returns an error", + providerConfig: providerConfig(nil), setupMocks: func(conn *mockldapconn.MockConn) { conn.EXPECT().Bind(testBindUsername, testBindPassword).Times(1) - conn.EXPECT().Search(expectedUserSearch).Return(happyPathUserSearchResult, nil).Times(1) - conn.EXPECT().SearchWithPaging(expectedGroupSearch, expectedGroupSearchPageSize).Return(nil, errors.New("some search error")).Times(1) + conn.EXPECT().Search(expectedUserSearch(nil)).Return(happyPathUserSearchResult, nil).Times(1) + conn.EXPECT().SearchWithPaging(expectedGroupSearch(nil), expectedGroupSearchPageSize).Return(nil, errors.New("some search error")).Times(1) conn.EXPECT().Close().Times(1) }, wantErr: "error searching for group memberships for user with DN \"some-upstream-user-dn\": some search error", @@ -1735,13 +1765,17 @@ func TestUpstreamRefresh(t *testing.T) { return conn, nil }) + if tt.refreshUserDN == "" { + tt.refreshUserDN = testUserSearchResultDNValue // default for all tests + } + initialPwdLastSetEncoded := base64.RawURLEncoding.EncodeToString([]byte("132801740800000000")) ldapProvider := New(*tt.providerConfig) subject := "ldaps://ldap.example.com:8443?base=some-upstream-user-base-dn&sub=c29tZS11cHN0cmVhbS11aWQtdmFsdWU" groups, err := ldapProvider.PerformRefresh(context.Background(), provider.StoredRefreshAttributes{ Username: testUserSearchResultUsernameAttributeValue, Subject: subject, - DN: testUserSearchResultDNValue, + DN: tt.refreshUserDN, AdditionalAttributes: map[string]string{pwdLastSetAttribute: initialPwdLastSetEncoded}, }) if tt.wantErr != "" {