From 1bf5527eac6b947010c8faf408f6747de2a2384f Mon Sep 17 00:00:00 2001 From: Lunny Xiao Date: Fri, 3 Nov 2023 09:41:00 +0800 Subject: [PATCH] Refactor Find Sources and fix bug when view a user who belongs to an unactive auth source (#27798) The steps to reproduce it. First, create a new oauth2 source. Then, a user login with this oauth2 source. Disable the oauth2 source. Visit users -> settings -> security, 500 will be displayed. This is because this page only load active Oauth2 sources but not all Oauth2 sources. --- cmd/admin_auth.go | 2 +- models/activities/statistic.go | 2 +- models/auth/oauth2.go | 9 ---- models/auth/source.go | 49 ++++++++----------- routers/web/admin/auths.go | 14 +++--- routers/web/admin/users.go | 10 ++-- routers/web/auth/auth.go | 12 ++--- routers/web/user/setting/security/security.go | 26 +++++++++- services/auth/signin.go | 5 +- services/auth/source/oauth2/init.go | 9 +++- services/auth/source/oauth2/providers.go | 47 ++++++++++-------- services/auth/sspi.go | 5 +- services/auth/sync.go | 2 +- templates/user/auth/signin_inner.tmpl | 7 ++- templates/user/auth/signup_inner.tmpl | 7 ++- 15 files changed, 115 insertions(+), 91 deletions(-) diff --git a/cmd/admin_auth.go b/cmd/admin_auth.go index 3b308d77f7987..014ddf329f94d 100644 --- a/cmd/admin_auth.go +++ b/cmd/admin_auth.go @@ -62,7 +62,7 @@ func runListAuth(c *cli.Context) error { return err } - authSources, err := auth_model.Sources(ctx) + authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{}) if err != nil { return err } diff --git a/models/activities/statistic.go b/models/activities/statistic.go index 009c8c5ab474e..e9dab6fc10b6b 100644 --- a/models/activities/statistic.go +++ b/models/activities/statistic.go @@ -102,7 +102,7 @@ func GetStatistic(ctx context.Context) (stats Statistic) { stats.Counter.Follow, _ = e.Count(new(user_model.Follow)) stats.Counter.Mirror, _ = e.Count(new(repo_model.Mirror)) stats.Counter.Release, _ = e.Count(new(repo_model.Release)) - stats.Counter.AuthSource = auth.CountSources(ctx) + stats.Counter.AuthSource = auth.CountSources(ctx, auth.FindSourcesOptions{}) stats.Counter.Webhook, _ = e.Count(new(webhook.Webhook)) stats.Counter.Milestone, _ = e.Count(new(issues_model.Milestone)) stats.Counter.Label, _ = e.Count(new(issues_model.Label)) diff --git a/models/auth/oauth2.go b/models/auth/oauth2.go index d73ad6965d2f0..76a4e9d835bcd 100644 --- a/models/auth/oauth2.go +++ b/models/auth/oauth2.go @@ -631,15 +631,6 @@ func (err ErrOAuthApplicationNotFound) Unwrap() error { return util.ErrNotExist } -// GetActiveOAuth2ProviderSources returns all actived LoginOAuth2 sources -func GetActiveOAuth2ProviderSources(ctx context.Context) ([]*Source, error) { - sources := make([]*Source, 0, 1) - if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, OAuth2).Find(&sources); err != nil { - return nil, err - } - return sources, nil -} - // GetActiveOAuth2SourceByName returns a OAuth2 AuthSource based on the given name func GetActiveOAuth2SourceByName(ctx context.Context, name string) (*Source, error) { authSource := new(Source) diff --git a/models/auth/source.go b/models/auth/source.go index 0f57d1702a774..b3f3262cc206c 100644 --- a/models/auth/source.go +++ b/models/auth/source.go @@ -14,6 +14,7 @@ import ( "code.gitea.io/gitea/modules/timeutil" "code.gitea.io/gitea/modules/util" + "xorm.io/builder" "xorm.io/xorm" "xorm.io/xorm/convert" ) @@ -240,37 +241,26 @@ func CreateSource(ctx context.Context, source *Source) error { return err } -// Sources returns a slice of all login sources found in DB. -func Sources(ctx context.Context) ([]*Source, error) { - auths := make([]*Source, 0, 6) - return auths, db.GetEngine(ctx).Find(&auths) +type FindSourcesOptions struct { + IsActive util.OptionalBool + LoginType Type } -// SourcesByType returns all sources of the specified type -func SourcesByType(ctx context.Context, loginType Type) ([]*Source, error) { - sources := make([]*Source, 0, 1) - if err := db.GetEngine(ctx).Where("type = ?", loginType).Find(&sources); err != nil { - return nil, err +func (opts FindSourcesOptions) ToConds() builder.Cond { + conds := builder.NewCond() + if !opts.IsActive.IsNone() { + conds = conds.And(builder.Eq{"is_active": opts.IsActive.IsTrue()}) } - return sources, nil -} - -// AllActiveSources returns all active sources -func AllActiveSources(ctx context.Context) ([]*Source, error) { - sources := make([]*Source, 0, 5) - if err := db.GetEngine(ctx).Where("is_active = ?", true).Find(&sources); err != nil { - return nil, err + if opts.LoginType != NoType { + conds = conds.And(builder.Eq{"`type`": opts.LoginType}) } - return sources, nil + return conds } -// ActiveSources returns all active sources of the specified type -func ActiveSources(ctx context.Context, tp Type) ([]*Source, error) { - sources := make([]*Source, 0, 1) - if err := db.GetEngine(ctx).Where("is_active = ? and type = ?", true, tp).Find(&sources); err != nil { - return nil, err - } - return sources, nil +// FindSources returns a slice of login sources found in DB according to given conditions. +func FindSources(ctx context.Context, opts FindSourcesOptions) ([]*Source, error) { + auths := make([]*Source, 0, 6) + return auths, db.GetEngine(ctx).Where(opts.ToConds()).Find(&auths) } // IsSSPIEnabled returns true if there is at least one activated login @@ -279,7 +269,10 @@ func IsSSPIEnabled(ctx context.Context) bool { if !db.HasEngine { return false } - sources, err := ActiveSources(ctx, SSPI) + sources, err := FindSources(ctx, FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + LoginType: SSPI, + }) if err != nil { log.Error("ActiveSources: %v", err) return false @@ -354,8 +347,8 @@ func UpdateSource(ctx context.Context, source *Source) error { } // CountSources returns number of login sources. -func CountSources(ctx context.Context) int64 { - count, _ := db.GetEngine(ctx).Count(new(Source)) +func CountSources(ctx context.Context, opts FindSourcesOptions) int64 { + count, _ := db.GetEngine(ctx).Where(opts.ToConds()).Count(new(Source)) return count } diff --git a/routers/web/admin/auths.go b/routers/web/admin/auths.go index da91e31efe11c..23946d64afa33 100644 --- a/routers/web/admin/auths.go +++ b/routers/web/admin/auths.go @@ -48,13 +48,13 @@ func Authentications(ctx *context.Context) { ctx.Data["PageIsAdminAuthentications"] = true var err error - ctx.Data["Sources"], err = auth.Sources(ctx) + ctx.Data["Sources"], err = auth.FindSources(ctx, auth.FindSourcesOptions{}) if err != nil { ctx.ServerError("auth.Sources", err) return } - ctx.Data["Total"] = auth.CountSources(ctx) + ctx.Data["Total"] = auth.CountSources(ctx, auth.FindSourcesOptions{}) ctx.HTML(http.StatusOK, tplAuths) } @@ -99,7 +99,7 @@ func NewAuthSource(ctx *context.Context) { ctx.Data["AuthSources"] = authSources ctx.Data["SecurityProtocols"] = securityProtocols ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers ctx.Data["SSPIAutoCreateUsers"] = true @@ -242,7 +242,7 @@ func NewAuthSourcePost(ctx *context.Context) { ctx.Data["AuthSources"] = authSources ctx.Data["SecurityProtocols"] = securityProtocols ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers ctx.Data["SSPIAutoCreateUsers"] = true @@ -284,7 +284,7 @@ func NewAuthSourcePost(ctx *context.Context) { ctx.RenderWithErr(err.Error(), tplAuthNew, form) return } - existing, err := auth.SourcesByType(ctx, auth.SSPI) + existing, err := auth.FindSources(ctx, auth.FindSourcesOptions{LoginType: auth.SSPI}) if err != nil || len(existing) > 0 { ctx.Data["Err_Type"] = true ctx.RenderWithErr(ctx.Tr("admin.auths.login_source_of_type_exist"), tplAuthNew, form) @@ -334,7 +334,7 @@ func EditAuthSource(ctx *context.Context) { ctx.Data["SecurityProtocols"] = securityProtocols ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers source, err := auth.GetSourceByID(ctx, ctx.ParamsInt64(":authid")) @@ -368,7 +368,7 @@ func EditAuthSourcePost(ctx *context.Context) { ctx.Data["PageIsAdminAuthentications"] = true ctx.Data["SMTPAuths"] = smtp.Authenticators - oauth2providers := oauth2.GetOAuth2Providers() + oauth2providers := oauth2.GetSupportedOAuth2Providers() ctx.Data["OAuth2Providers"] = oauth2providers source, err := auth.GetSourceByID(ctx, ctx.ParamsInt64(":authid")) diff --git a/routers/web/admin/users.go b/routers/web/admin/users.go index 91a578fb55482..630d739836beb 100644 --- a/routers/web/admin/users.go +++ b/routers/web/admin/users.go @@ -90,7 +90,9 @@ func NewUser(ctx *context.Context) { ctx.Data["login_type"] = "0-0" - sources, err := auth.Sources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + }) if err != nil { ctx.ServerError("auth.Sources", err) return @@ -109,7 +111,9 @@ func NewUserPost(ctx *context.Context) { ctx.Data["DefaultUserVisibilityMode"] = setting.Service.DefaultUserVisibilityMode ctx.Data["AllowedUserVisibilityModes"] = setting.Service.AllowedUserVisibilityModesSlice.ToVisibleTypeSlice() - sources, err := auth.Sources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + }) if err != nil { ctx.ServerError("auth.Sources", err) return @@ -230,7 +234,7 @@ func prepareUserInfo(ctx *context.Context) *user_model.User { ctx.Data["LoginSource"] = &auth.Source{} } - sources, err := auth.Sources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{}) if err != nil { ctx.ServerError("auth.Sources", err) return nil diff --git a/routers/web/auth/auth.go b/routers/web/auth/auth.go index e27307ef1afc6..0ea91fc759a9a 100644 --- a/routers/web/auth/auth.go +++ b/routers/web/auth/auth.go @@ -160,12 +160,11 @@ func SignIn(ctx *context.Context) { return } - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignIn", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers ctx.Data["Title"] = ctx.Tr("sign_in") ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" @@ -184,12 +183,11 @@ func SignIn(ctx *context.Context) { func SignInPost(ctx *context.Context) { ctx.Data["Title"] = ctx.Tr("sign_in") - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignIn", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers ctx.Data["Title"] = ctx.Tr("sign_in") ctx.Data["SignInLink"] = setting.AppSubURL + "/user/login" @@ -408,13 +406,12 @@ func SignUp(ctx *context.Context) { ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/sign_up" - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignUp", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers context.SetCaptchaData(ctx) @@ -438,13 +435,12 @@ func SignUpPost(ctx *context.Context) { ctx.Data["SignUpLink"] = setting.AppSubURL + "/user/sign_up" - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + oauth2Providers, err := oauth2.GetOAuth2Providers(ctx, util.OptionalBoolTrue) if err != nil { ctx.ServerError("UserSignUp", err) return } - ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers context.SetCaptchaData(ctx) diff --git a/routers/web/user/setting/security/security.go b/routers/web/user/setting/security/security.go index 58c637e2b3365..e64901ae728e7 100644 --- a/routers/web/user/setting/security/security.go +++ b/routers/web/user/setting/security/security.go @@ -6,12 +6,14 @@ package security import ( "net/http" + "sort" auth_model "code.gitea.io/gitea/models/auth" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/base" "code.gitea.io/gitea/modules/context" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/services/auth/source/oauth2" ) @@ -105,11 +107,31 @@ func loadSecurityData(ctx *context.Context) { } ctx.Data["AccountLinks"] = sources - orderedOAuth2Names, oauth2Providers, err := oauth2.GetActiveOAuth2Providers(ctx) + authSources, err := auth_model.FindSources(ctx, auth_model.FindSourcesOptions{ + IsActive: util.OptionalBoolNone, + LoginType: auth_model.OAuth2, + }) if err != nil { - ctx.ServerError("GetActiveOAuth2Providers", err) + ctx.ServerError("FindSources", err) return } + + var orderedOAuth2Names []string + oauth2Providers := make(map[string]oauth2.Provider) + for _, source := range authSources { + provider, err := oauth2.CreateProviderFromSource(source) + if err != nil { + ctx.ServerError("CreateProviderFromSource", err) + return + } + oauth2Providers[source.Name] = provider + if source.IsActive { + orderedOAuth2Names = append(orderedOAuth2Names, source.Name) + } + } + + sort.Strings(orderedOAuth2Names) + ctx.Data["OrderedOAuth2Names"] = orderedOAuth2Names ctx.Data["OAuth2Providers"] = oauth2Providers diff --git a/services/auth/signin.go b/services/auth/signin.go index 5fdf6d2bd7ab2..2e534536817e4 100644 --- a/services/auth/signin.go +++ b/services/auth/signin.go @@ -11,6 +11,7 @@ import ( "code.gitea.io/gitea/models/db" user_model "code.gitea.io/gitea/models/user" "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/util" "code.gitea.io/gitea/services/auth/source/oauth2" "code.gitea.io/gitea/services/auth/source/smtp" @@ -85,7 +86,9 @@ func UserSignIn(ctx context.Context, username, password string) (*user_model.Use } } - sources, err := auth.AllActiveSources(ctx) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + }) if err != nil { return nil, nil, err } diff --git a/services/auth/source/oauth2/init.go b/services/auth/source/oauth2/init.go index cfaddaa35d55c..0ebbdaebd411f 100644 --- a/services/auth/source/oauth2/init.go +++ b/services/auth/source/oauth2/init.go @@ -12,6 +12,7 @@ import ( "code.gitea.io/gitea/models/auth" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" "github.com/google/uuid" "github.com/gorilla/sessions" @@ -63,7 +64,13 @@ func ResetOAuth2(ctx context.Context) error { // initOAuth2Sources is used to load and register all active OAuth2 providers func initOAuth2Sources(ctx context.Context) error { - authSources, _ := auth.GetActiveOAuth2ProviderSources(ctx) + authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + LoginType: auth.OAuth2, + }) + if err != nil { + return err + } for _, source := range authSources { oauth2Source, ok := source.Cfg.(*Source) if !ok { diff --git a/services/auth/source/oauth2/providers.go b/services/auth/source/oauth2/providers.go index cd158614a2e4e..3b45b252f7099 100644 --- a/services/auth/source/oauth2/providers.go +++ b/services/auth/source/oauth2/providers.go @@ -15,6 +15,7 @@ import ( "code.gitea.io/gitea/models/auth" "code.gitea.io/gitea/modules/log" "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/modules/util" "github.com/markbates/goth" ) @@ -80,10 +81,10 @@ func RegisterGothProvider(provider GothProvider) { gothProviders[provider.Name()] = provider } -// GetOAuth2Providers returns the map of unconfigured OAuth2 providers +// GetSupportedOAuth2Providers returns the map of unconfigured OAuth2 providers // key is used as technical name (like in the callbackURL) // values to display -func GetOAuth2Providers() []Provider { +func GetSupportedOAuth2Providers() []Provider { providers := make([]Provider, 0, len(gothProviders)) for _, provider := range gothProviders { @@ -95,33 +96,39 @@ func GetOAuth2Providers() []Provider { return providers } -// GetActiveOAuth2Providers returns the map of configured active OAuth2 providers -// key is used as technical name (like in the callbackURL) -// values to display -func GetActiveOAuth2Providers(ctx context.Context) ([]string, map[string]Provider, error) { - // Maybe also separate used and unused providers so we can force the registration of only 1 active provider for each type +func CreateProviderFromSource(source *auth.Source) (Provider, error) { + oauth2Cfg, ok := source.Cfg.(*Source) + if !ok { + return nil, fmt.Errorf("invalid OAuth2 source config: %v", oauth2Cfg) + } + gothProv := gothProviders[oauth2Cfg.Provider] + return &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL}, nil +} - authSources, err := auth.GetActiveOAuth2ProviderSources(ctx) +// GetOAuth2Providers returns the list of configured OAuth2 providers +func GetOAuth2Providers(ctx context.Context, isActive util.OptionalBool) ([]Provider, error) { + authSources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: isActive, + LoginType: auth.OAuth2, + }) if err != nil { - return nil, nil, err + return nil, err } - var orderedKeys []string - providers := make(map[string]Provider) + providers := make([]Provider, 0, len(authSources)) for _, source := range authSources { - oauth2Cfg, ok := source.Cfg.(*Source) - if !ok { - log.Error("Invalid OAuth2 source config: %v", oauth2Cfg) - continue + provider, err := CreateProviderFromSource(source) + if err != nil { + return nil, err } - gothProv := gothProviders[oauth2Cfg.Provider] - providers[source.Name] = &AuthSourceProvider{GothProvider: gothProv, sourceName: source.Name, iconURL: oauth2Cfg.IconURL} - orderedKeys = append(orderedKeys, source.Name) + providers = append(providers, provider) } - sort.Strings(orderedKeys) + sort.Slice(providers, func(i, j int) bool { + return providers[i].Name() < providers[j].Name() + }) - return orderedKeys, providers, nil + return providers, nil } // RegisterProviderWithGothic register a OAuth2 provider in goth lib diff --git a/services/auth/sspi.go b/services/auth/sspi.go index 573d94b42c2c0..bc8ec948f29cd 100644 --- a/services/auth/sspi.go +++ b/services/auth/sspi.go @@ -130,7 +130,10 @@ func (s *SSPI) Verify(req *http.Request, w http.ResponseWriter, store DataStore, // getConfig retrieves the SSPI configuration from login sources func (s *SSPI) getConfig(ctx context.Context) (*sspi.Source, error) { - sources, err := auth.ActiveSources(ctx, auth.SSPI) + sources, err := auth.FindSources(ctx, auth.FindSourcesOptions{ + IsActive: util.OptionalBoolTrue, + LoginType: auth.SSPI, + }) if err != nil { return nil, err } diff --git a/services/auth/sync.go b/services/auth/sync.go index 25b9460b9921f..11a59d41ae1b4 100644 --- a/services/auth/sync.go +++ b/services/auth/sync.go @@ -15,7 +15,7 @@ import ( func SyncExternalUsers(ctx context.Context, updateExisting bool) error { log.Trace("Doing: SyncExternalUsers") - ls, err := auth.Sources(ctx) + ls, err := auth.FindSources(ctx, auth.FindSourcesOptions{}) if err != nil { log.Error("SyncExternalUsers: %v", err) return err diff --git a/templates/user/auth/signin_inner.tmpl b/templates/user/auth/signin_inner.tmpl index f38b0a26087f5..7f744b24d87bf 100644 --- a/templates/user/auth/signin_inner.tmpl +++ b/templates/user/auth/signin_inner.tmpl @@ -52,16 +52,15 @@ {{end}} - {{if and .OrderedOAuth2Names .OAuth2Providers}} + {{if .OAuth2Providers}}
{{ctx.Locale.Tr "sign_in_or"}}
- {{range $key := .OrderedOAuth2Names}} - {{$provider := index $.OAuth2Providers $key}} - diff --git a/templates/user/auth/signup_inner.tmpl b/templates/user/auth/signup_inner.tmpl index 068ccbc6182c1..c75e33a18a0e9 100644 --- a/templates/user/auth/signup_inner.tmpl +++ b/templates/user/auth/signup_inner.tmpl @@ -56,16 +56,15 @@ {{end}} {{end}} - {{if and .OrderedOAuth2Names .OAuth2Providers}} + {{if .OAuth2Providers}}
{{ctx.Locale.Tr "sign_in_or"}}