Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Find Sources and fix bug when view a user who belongs to an unactive auth source #27798

Merged
merged 8 commits into from
Nov 3, 2023
2 changes: 1 addition & 1 deletion cmd/admin_auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error {
return err
}

authSources, err := auth_model.Sources(ctx)
authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{})
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion models/activities/statistic.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) {
stats.Counter.Follow, _ = e.Count(new(user_model.Follow))
stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror))
stats.Counter.Release, _ = e.Count(new(repo_model.Release))
stats.Counter.AuthSource = auth.CountSources(ctx)
stats.Counter.AuthSource = auth.CountSources(ctx, auth.FindSourcesOptions{})
stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook))
stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone))
stats.Counter.Label, _ = e.Count(new(issues_model.Label))
Expand Down
9 changes: 0 additions & 9 deletions models/auth/oauth2.go
Original file line number Diff line number Diff line change
Expand Up @@ -631,15 +631,6 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error {
return util.ErrNotExist
}

// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources
func GetActiveOAuth2ProviderSources(ctx context.Context) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil {
return nil, err
}
return sources, nil
}

// GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name
func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) {
authSource := new(Source)
Expand Down
49 changes: 21 additions & 28 deletions models/auth/source.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
"code.gitea.io/gitea/modules/timeutil"
"code.gitea.io/gitea/modules/util"

"xorm.io/builder"
"xorm.io/xorm"
"xorm.io/xorm/convert"
)
Expand Down Expand Up @@ -240,37 +241,26 @@ func CreateSource(ctx context.Context, source *Source) error {
return err
}

// Sources returns a slice of all login sources found in DB.
func Sources(ctx context.Context) ([]*Source, error) {
auths := make([]*Source, 0, 6)
return auths, db.GetEngine(ctx).Find(&auths)
type FindSourcesOptions struct {
IsActive util.OptionalBool
LoginType Type
}

// SourcesByType returns all sources of the specified type
func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil {
return nil, err
func (opts FindSourcesOptions) ToConds() builder.Cond {
conds := builder.NewCond()
if !opts.IsActive.IsNone() {
conds = conds.And(builder.Eq{"is_active": opts.IsActive.IsTrue()})
}
return sources, nil
}

// AllActiveSources returns all active sources
func AllActiveSources(ctx context.Context) ([]*Source, error) {
sources := make([]*Source, 0, 5)
if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil {
return nil, err
if opts.LoginType != NoType {
conds = conds.And(builder.Eq{"`type`": opts.LoginType})
}
return sources, nil
return conds
}

// ActiveSources returns all active sources of the specified type
func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) {
sources := make([]*Source, 0, 1)
if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil {
return nil, err
}
return sources, nil
// FindSources returns a slice of login sources found in DB according to given conditions.
func FindSources(ctx context.Context, opts FindSourcesOptions) ([]*Source, error) {
auths := make([]*Source, 0, 6)
return auths, db.GetEngine(ctx).Where(opts.ToConds()).Find(&auths)
}

// IsSSPIEnabled returns true if there is at least one activated login
Expand All @@ -279,7 +269,10 @@ func IsSSPIEnabled(ctx context.Context) bool {
if !db.HasEngine {
return false
}
sources, err := ActiveSources(ctx, SSPI)
sources, err := FindSources(ctx, FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
LoginType: SSPI,
})
if err != nil {
log.Error("ActiveSources: %v", err)
return false
Expand Down Expand Up @@ -354,8 +347,8 @@ func UpdateSource(ctx context.Context, source *Source) error {
}

// CountSources returns number of login sources.
func CountSources(ctx context.Context) int64 {
count, _ := db.GetEngine(ctx).Count(new(Source))
func CountSources(ctx context.Context, opts FindSourcesOptions) int64 {
count, _ := db.GetEngine(ctx).Where(opts.ToConds()).Count(new(Source))
return count
}

Expand Down
6 changes: 3 additions & 3 deletions routers/web/admin/auths.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ func Authentications(ctx *context.Context) {
ctx.Data["PageIsAdminAuthentications"] = true

var err error
ctx.Data["Sources"], err = auth.Sources(ctx)
ctx.Data["Sources"], err = auth.FindSources(ctx, auth.FindSourcesOptions{})
if err != nil {
ctx.ServerError("auth.Sources", err)
return
}

ctx.Data["Total"] = auth.CountSources(ctx)
ctx.Data["Total"] = auth.CountSources(ctx, auth.FindSourcesOptions{})
ctx.HTML(http.StatusOK, tplAuths)
}

Expand Down Expand Up @@ -284,7 +284,7 @@ func NewAuthSourcePost(ctx *context.Context) {
ctx.RenderWithErr(err.Error(), tplAuthNew, form)
return
}
existing, err := auth.SourcesByType(ctx, auth.SSPI)
existing, err := auth.FindSources(ctx, auth.FindSourcesOptions{LoginType: auth.SSPI})
if err != nil || len(existing) > 0 {
ctx.Data["Err_Type"] = true
ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_of_type_exist"), tplAuthNew, form)
Expand Down
10 changes: 7 additions & 3 deletions routers/web/admin/users.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ func NewUser(ctx *context.Context) {

ctx.Data["login_type"] = "0-0"

sources, err := auth.Sources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
})
if err != nil {
ctx.ServerError("auth.Sources", err)
return
Expand All @@ -109,7 +111,9 @@ func NewUserPost(ctx *context.Context) {
ctx.Data["DefaultUserVisibilityMode"] = setting.Service.DefaultUserVisibilityMode
ctx.Data["AllowedUserVisibilityModes"] = setting.Service.AllowedUserVisibilityModesSlice.ToVisibleTypeSlice()

sources, err := auth.Sources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
})
if err != nil {
ctx.ServerError("auth.Sources", err)
return
Expand Down Expand Up @@ -230,7 +234,7 @@ func prepareUserInfo(ctx *context.Context) *user_model.User {
ctx.Data["LoginSource"] = &auth.Source{}
}

sources, err := auth.Sources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{})
if err != nil {
ctx.ServerError("auth.Sources", err)
return nil
Expand Down
26 changes: 24 additions & 2 deletions routers/web/user/setting/security/security.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,14 @@ package security

import (
"net/http"
"sort"

auth_model "code.gitea.io/gitea/models/auth"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/base"
"code.gitea.io/gitea/modules/context"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/services/auth/source/oauth2"
)

Expand Down Expand Up @@ -105,11 +107,31 @@ func loadSecurityData(ctx *context.Context) {
}
ctx.Data["AccountLinks"] = sources

orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx)
authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{
IsActive: util.OptionalBoolNone,
LoginType: auth_model.OAuth2,
})
if err != nil {
ctx.ServerError("GetActiveOAuth2Providers", err)
ctx.ServerError("GetOAuth2ProvidersMap", err)
return
}

var orderedOAuth2Names []string
oauth2Providers := make(map[string]oauth2.Provider)
for _, source := range authSources {
provider, err := oauth2.CreateProviderFromSource(source)
if err != nil {
ctx.ServerError("CreateProviderFromSource", err)
return
}
oauth2Providers[source.Name] = provider
if source.IsActive {
orderedOAuth2Names = append(orderedOAuth2Names, source.Name)
}
}

sort.Strings(orderedOAuth2Names)
lunny marked this conversation as resolved.
Show resolved Hide resolved

ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names
ctx.Data["OAuth2Providers"] = oauth2Providers

Expand Down
5 changes: 4 additions & 1 deletion services/auth/signin.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"code.gitea.io/gitea/models/db"
user_model "code.gitea.io/gitea/models/user"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/util"
"code.gitea.io/gitea/services/auth/source/oauth2"
"code.gitea.io/gitea/services/auth/source/smtp"

Expand Down Expand Up @@ -85,7 +86,9 @@ func UserSignIn(ctx context.Context, username, password string) (*user_model.Use
}
}

sources, err := auth.AllActiveSources(ctx)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
})
if err != nil {
return nil, nil, err
}
Expand Down
9 changes: 8 additions & 1 deletion services/auth/source/oauth2/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"

"github.com/google/uuid"
"github.com/gorilla/sessions"
Expand Down Expand Up @@ -63,7 +64,13 @@ func ResetOAuth2(ctx context.Context) error {

// initOAuth2Sources is used to load and register all active OAuth2 providers
func initOAuth2Sources(ctx context.Context) error {
authSources, _ := auth.GetActiveOAuth2ProviderSources(ctx)
authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
LoginType: auth.OAuth2,
})
if err != nil {
return err
}
for _, source := range authSources {
oauth2Source, ok := source.Cfg.(*Source)
if !ok {
Expand Down
39 changes: 27 additions & 12 deletions services/auth/source/oauth2/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"code.gitea.io/gitea/models/auth"
"code.gitea.io/gitea/modules/log"
"code.gitea.io/gitea/modules/setting"
"code.gitea.io/gitea/modules/util"

"github.com/markbates/goth"
)
Expand Down Expand Up @@ -95,27 +96,33 @@ func GetOAuth2Providers() []Provider {
return providers
}

// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers
// key is used as technical name (like in the callbackURL)
// values to display
func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provider, error) {
// Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type
func CreateProviderFromSource(source *auth.Source) (Provider, error) {
oauth2Cfg, ok := source.Cfg.(*Source)
if !ok {
return nil, fmt.Errorf("invalid OAuth2 source config: %v", oauth2Cfg)
}
gothProv := gothProviders[oauth2Cfg.Provider]
return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil
}

authSources, err := auth.GetActiveOAuth2ProviderSources(ctx)
// GetOAuth2ProvidersMap returns the map of configured OAuth2 providers
func GetOAuth2ProvidersMap(ctx context.Context, isActive util.OptionalBool) ([]string, map[string]Provider, error) {
authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: isActive,
LoginType: auth.OAuth2,
})
if err != nil {
return nil, nil, err
}

var orderedKeys []string
providers := make(map[string]Provider)
for _, source := range authSources {
oauth2Cfg, ok := source.Cfg.(*Source)
if !ok {
log.Error("Invalid OAuth2 source config: %v", oauth2Cfg)
continue
provider, err := CreateProviderFromSource(source)
if err != nil {
return nil, nil, err
}
gothProv := gothProviders[oauth2Cfg.Provider]
providers[source.Name] = &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}
providers[source.Name] = provider
orderedKeys = append(orderedKeys, source.Name)
}

Expand All @@ -124,6 +131,14 @@ func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provide
return orderedKeys, providers, nil
}

// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers
// key is used as technical name (like in the callbackURL)
// values to display
func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provider, error) {
// Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type
return GetOAuth2ProvidersMap(ctx, util.OptionalBoolTrue)
}

// RegisterProviderWithGothic register a OAuth2 provider in goth lib
func RegisterProviderWithGothic(providerName string, source *Source) error {
provider, err := createProvider(providerName, source)
Expand Down
5 changes: 4 additions & 1 deletion services/auth/sspi.go
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,10 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore,

// getConfig retrieves the SSPI configuration from login sources
func (s *SSPI) getConfig(ctx context.Context) (*sspi.Source, error) {
sources, err := auth.ActiveSources(ctx, auth.SSPI)
sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{
IsActive: util.OptionalBoolTrue,
LoginType: auth.SSPI,
})
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion services/auth/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (
func SyncExternalUsers(ctx context.Context, updateExisting bool) error {
log.Trace("Doing: SyncExternalUsers")

ls, err := auth.Sources(ctx)
ls, err := auth.FindSources(ctx, auth.FindSourcesOptions{})
if err != nil {
log.Error("SyncExternalUsers: %v", err)
return err
Expand Down