Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: refactor auth package #7110

Merged
merged 13 commits into from
Feb 17, 2025
2 changes: 2 additions & 0 deletions ee/query-service/app/api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
basemodel "go.signoz.io/signoz/pkg/query-service/model"
rules "go.signoz.io/signoz/pkg/query-service/rules"
"go.signoz.io/signoz/pkg/query-service/version"
"go.signoz.io/signoz/pkg/types/authtypes"
)

type APIHandlerOptions struct {
Expand All @@ -41,6 +42,7 @@ type APIHandlerOptions struct {
FluxInterval time.Duration
UseLogsNewSchema bool
UseTraceNewSchema bool
JWT *authtypes.JWT
}

type APIHandler struct {
Expand Down
6 changes: 3 additions & 3 deletions ee/query-service/app/api/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func (ah *APIHandler) loginUser(w http.ResponseWriter, r *http.Request) {
}

// if all looks good, call auth
resp, err := baseauth.Login(ctx, &req)
resp, err := baseauth.Login(ctx, &req, ah.opts.JWT)
if ah.HandleError(w, err, http.StatusUnauthorized) {
return
}
Expand Down Expand Up @@ -253,7 +253,7 @@ func (ah *APIHandler) receiveGoogleAuth(w http.ResponseWriter, r *http.Request)
return
}

nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, identity.Email)
nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, identity.Email, ah.opts.JWT)
if err != nil {
zap.L().Error("[receiveGoogleAuth] failed to generate redirect URI after successful login ", zap.String("domain", domain.String()), zap.Error(err))
handleSsoError(w, r, redirectUri)
Expand Down Expand Up @@ -331,7 +331,7 @@ func (ah *APIHandler) receiveSAML(w http.ResponseWriter, r *http.Request) {
return
}

nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, email)
nextPage, err := ah.AppDao().PrepareSsoRedirect(ctx, redirectUri, email, ah.opts.JWT)
if err != nil {
zap.L().Error("[receiveSAML] failed to generate redirect URI after successful login ", zap.String("domain", domain.String()), zap.Error(err))
handleSsoError(w, r, redirectUri)
Expand Down
2 changes: 1 addition & 1 deletion ee/query-service/app/api/cloudIntegrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ func (ah *APIHandler) CloudIntegrationsGenerateConnectionParams(w http.ResponseW
return
}

currentUser, err := auth.GetUserFromRequest(r)
currentUser, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, basemodel.UnauthorizedError(fmt.Errorf(
"couldn't deduce current user: %w", err,
Expand Down
8 changes: 4 additions & 4 deletions ee/query-service/app/api/pat.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (ah *APIHandler) createPAT(w http.ResponseWriter, r *http.Request) {
RespondError(w, model.BadRequest(err), nil)
return
}
user, err := auth.GetUserFromRequest(r)
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Expand Down Expand Up @@ -97,7 +97,7 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {
return
}

user, err := auth.GetUserFromRequest(r)
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Expand Down Expand Up @@ -127,7 +127,7 @@ func (ah *APIHandler) updatePAT(w http.ResponseWriter, r *http.Request) {

func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
user, err := auth.GetUserFromRequest(r)
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Expand All @@ -147,7 +147,7 @@ func (ah *APIHandler) getPATs(w http.ResponseWriter, r *http.Request) {
func (ah *APIHandler) revokePAT(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()
id := mux.Vars(r)["id"]
user, err := auth.GetUserFromRequest(r)
user, err := auth.GetUserFromReqContext(r.Context())
if err != nil {
RespondError(w, &model.ApiError{
Typ: model.ErrorUnauthorized,
Expand Down
24 changes: 13 additions & 11 deletions ee/query-service/app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ import (
"go.signoz.io/signoz/ee/query-service/interfaces"
"go.signoz.io/signoz/ee/query-service/rules"
"go.signoz.io/signoz/pkg/http/middleware"
baseauth "go.signoz.io/signoz/pkg/query-service/auth"
v3 "go.signoz.io/signoz/pkg/query-service/model/v3"
"go.signoz.io/signoz/pkg/signoz"
"go.signoz.io/signoz/pkg/types/authtypes"
"go.signoz.io/signoz/pkg/web"

licensepkg "go.signoz.io/signoz/ee/query-service/license"
Expand Down Expand Up @@ -80,6 +80,7 @@ type ServerOptions struct {
GatewayUrl string
UseLogsNewSchema bool
UseTraceNewSchema bool
Jwt *authtypes.JWT
}

// Server runs HTTP api service
Expand Down Expand Up @@ -269,6 +270,7 @@ func NewServer(serverOptions *ServerOptions) (*Server, error) {
GatewayUrl: serverOptions.GatewayUrl,
UseLogsNewSchema: serverOptions.UseLogsNewSchema,
UseTraceNewSchema: serverOptions.UseTraceNewSchema,
JWT: serverOptions.Jwt,
}

apiHandler, err := api.NewAPIHandler(apiOpts)
Expand Down Expand Up @@ -311,6 +313,7 @@ func (s *Server) createPrivateServer(apiHandler *api.APIHandler) (*http.Server,

r := baseapp.NewRouter()

r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt).Wrap)
r.Use(middleware.NewTimeout(zap.L(),
s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes,
s.serverOptions.Config.APIServer.Timeout.Default,
Expand Down Expand Up @@ -342,8 +345,8 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*h
r := baseapp.NewRouter()

// add auth middleware
getUserFromRequest := func(r *http.Request) (*basemodel.UserPayload, error) {
user, err := auth.GetUserFromRequest(r, apiHandler)
getUserFromRequest := func(ctx context.Context) (*basemodel.UserPayload, error) {
user, err := auth.GetUserFromRequestContext(ctx, apiHandler)

if err != nil {
return nil, err
Expand All @@ -357,6 +360,7 @@ func (s *Server) createPublicServer(apiHandler *api.APIHandler, web web.Web) (*h
}
am := baseapp.NewAuthMiddleware(getUserFromRequest)

r.Use(middleware.NewAuth(zap.L(), s.serverOptions.Jwt).Wrap)
r.Use(middleware.NewTimeout(zap.L(),
s.serverOptions.Config.APIServer.Timeout.ExcludedRoutes,
s.serverOptions.Config.APIServer.Timeout.Default,
Expand Down Expand Up @@ -497,8 +501,8 @@ func extractQueryRangeData(path string, r *http.Request) (map[string]interface{}
data["queryType"] = queryInfoResult.QueryType
data["panelType"] = queryInfoResult.PanelType

userEmail, err := baseauth.GetEmailFromJwt(r.Context())
if err == nil {
claims, ok := authtypes.GetClaimsFromContext(r.Context())
if ok {
// switch case to set data["screen"] based on the referrer
switch {
case dashboardMatched:
Expand All @@ -513,7 +517,7 @@ func extractQueryRangeData(path string, r *http.Request) (map[string]interface{}
data["screen"] = "unknown"
return data, true
}
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_QUERY_RANGE_API, data, userEmail, true, false)
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_QUERY_RANGE_API, data, claims.Email, true, false)
}
}
return data, true
Expand All @@ -535,8 +539,6 @@ func getActiveLogs(path string, r *http.Request) {

func (s *Server) analyticsMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := baseauth.AttachJwtToContext(r.Context(), r)
r = r.WithContext(ctx)
route := mux.CurrentRoute(r)
path, _ := route.GetPathTemplate()

Expand All @@ -554,9 +556,9 @@ func (s *Server) analyticsMiddleware(next http.Handler) http.Handler {
}

if _, ok := telemetry.EnabledPaths()[path]; ok {
userEmail, err := baseauth.GetEmailFromJwt(r.Context())
if err == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, userEmail, true, false)
claims, ok := authtypes.GetClaimsFromContext(r.Context())
if ok {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, claims.Email, true, false)
}
}

Expand Down
10 changes: 5 additions & 5 deletions ee/query-service/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ package auth
import (
"context"
"fmt"
"net/http"
"time"

"go.signoz.io/signoz/ee/query-service/app/api"
baseauth "go.signoz.io/signoz/pkg/query-service/auth"
basemodel "go.signoz.io/signoz/pkg/query-service/model"
"go.signoz.io/signoz/pkg/query-service/telemetry"
"go.signoz.io/signoz/pkg/types/authtypes"

"go.uber.org/zap"
)

func GetUserFromRequest(r *http.Request, apiHandler *api.APIHandler) (*basemodel.UserPayload, error) {
patToken := r.Header.Get("SIGNOZ-API-KEY")
if len(patToken) > 0 {
func GetUserFromRequestContext(ctx context.Context, apiHandler *api.APIHandler) (*basemodel.UserPayload, error) {
patToken, ok := ctx.Value(authtypes.PatTokenKey).(string)
if ok && patToken != "" {
zap.L().Debug("Received a non-zero length PAT token")
ctx := context.Background()
dao := apiHandler.AppDao()
Expand Down Expand Up @@ -52,5 +52,5 @@ func GetUserFromRequest(r *http.Request, apiHandler *api.APIHandler) (*basemodel
return nil, err
}
}
return baseauth.GetUserFromRequest(r)
return baseauth.GetUserFromReqContext(ctx)
}
3 changes: 2 additions & 1 deletion ee/query-service/dao/interface.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
basedao "go.signoz.io/signoz/pkg/query-service/dao"
baseint "go.signoz.io/signoz/pkg/query-service/interfaces"
basemodel "go.signoz.io/signoz/pkg/query-service/model"
"go.signoz.io/signoz/pkg/types/authtypes"
)

type ModelDao interface {
Expand All @@ -22,7 +23,7 @@ type ModelDao interface {

// auth methods
CanUsePassword(ctx context.Context, email string) (bool, basemodel.BaseApiError)
PrepareSsoRedirect(ctx context.Context, redirectUri, email string) (redirectURL string, apierr basemodel.BaseApiError)
PrepareSsoRedirect(ctx context.Context, redirectUri, email string, jwt *authtypes.JWT) (redirectURL string, apierr basemodel.BaseApiError)
GetDomainFromSsoResponse(ctx context.Context, relayState *url.URL) (*model.OrgDomain, error)

// org domain (auth domains) CRUD ops
Expand Down
5 changes: 3 additions & 2 deletions ee/query-service/dao/sqlite/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import (
baseconst "go.signoz.io/signoz/pkg/query-service/constants"
basemodel "go.signoz.io/signoz/pkg/query-service/model"
"go.signoz.io/signoz/pkg/query-service/utils"
"go.signoz.io/signoz/pkg/types/authtypes"
"go.uber.org/zap"
)

Expand Down Expand Up @@ -64,7 +65,7 @@ func (m *modelDao) createUserForSAMLRequest(ctx context.Context, email string) (

// PrepareSsoRedirect prepares redirect page link after SSO response
// is successfully parsed (i.e. valid email is available)
func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email string) (redirectURL string, apierr basemodel.BaseApiError) {
func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email string, jwt *authtypes.JWT) (redirectURL string, apierr basemodel.BaseApiError) {

userPayload, apierr := m.GetUserByEmail(ctx, email)
if !apierr.IsNil() {
Expand All @@ -85,7 +86,7 @@ func (m *modelDao) PrepareSsoRedirect(ctx context.Context, redirectUri, email st
user = &userPayload.User
}

tokenStore, err := baseauth.GenerateJWTForUser(user)
tokenStore, err := baseauth.GenerateJWTForUser(user, jwt)
if err != nil {
zap.L().Error("failed to generate token for SSO login user", zap.Error(err))
return "", model.InternalErrorStr("failed to generate token for the user")
Expand Down
8 changes: 4 additions & 4 deletions ee/query-service/license/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@ import (

"sync"

"go.signoz.io/signoz/pkg/query-service/auth"
baseconstants "go.signoz.io/signoz/pkg/query-service/constants"
"go.signoz.io/signoz/pkg/types/authtypes"

validate "go.signoz.io/signoz/ee/query-service/integrations/signozio"
"go.signoz.io/signoz/ee/query-service/model"
Expand Down Expand Up @@ -237,10 +237,10 @@ func (lm *Manager) ValidateV3(ctx context.Context) (reterr error) {
func (lm *Manager) ActivateV3(ctx context.Context, licenseKey string) (licenseResponse *model.LicenseV3, errResponse *model.ApiError) {
defer func() {
if errResponse != nil {
userEmail, err := auth.GetEmailFromJwt(ctx)
if err == nil {
claims, ok := authtypes.GetClaimsFromContext(ctx)
if ok {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_LICENSE_ACT_FAILED,
map[string]interface{}{"err": errResponse.Err.Error()}, userEmail, true, false)
map[string]interface{}{"err": errResponse.Err.Error()}, claims.Email, true, false)
}
}
}()
Expand Down
21 changes: 12 additions & 9 deletions ee/query-service/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
baseconst "go.signoz.io/signoz/pkg/query-service/constants"
"go.signoz.io/signoz/pkg/query-service/version"
"go.signoz.io/signoz/pkg/signoz"
"go.signoz.io/signoz/pkg/types/authtypes"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"

Expand Down Expand Up @@ -154,6 +155,16 @@ func main() {
zap.L().Fatal("Failed to create signoz struct", zap.Error(err))
}

jwtSecret := os.Getenv("SIGNOZ_JWT_SECRET")

if len(jwtSecret) == 0 {
zap.L().Warn("No JWT secret key is specified.")
} else {
zap.L().Info("JWT secret key set successfully.")
}

jwt := authtypes.NewJWT(jwtSecret, 30*time.Minute, 30*24*time.Hour)

serverOptions := &app.ServerOptions{
Config: config,
SigNoz: signoz,
Expand All @@ -171,15 +182,7 @@ func main() {
GatewayUrl: gatewayUrl,
UseLogsNewSchema: useLogsNewSchema,
UseTraceNewSchema: useTraceNewSchema,
}

// Read the jwt secret key
auth.JwtSecret = os.Getenv("SIGNOZ_JWT_SECRET")

if len(auth.JwtSecret) == 0 {
zap.L().Warn("No JWT secret key is specified.")
} else {
zap.L().Info("JWT secret key set successfully.")
Jwt: jwt,
}

server, err := app.NewServer(serverOptions)
Expand Down
18 changes: 7 additions & 11 deletions pkg/http/middleware/analytics.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,10 @@ import (
"net/http"
"regexp"

// TODO(remove): Remove auth packages
"go.signoz.io/signoz/pkg/query-service/auth"

"github.com/gorilla/mux"
v3 "go.signoz.io/signoz/pkg/query-service/model/v3"
"go.signoz.io/signoz/pkg/query-service/telemetry"
"go.signoz.io/signoz/pkg/types/authtypes"
"go.uber.org/zap"
)

Expand All @@ -30,8 +28,6 @@ func NewAnalytics(logger *zap.Logger) *Analytics {

func (a *Analytics) Wrap(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := auth.AttachJwtToContext(r.Context(), r)
r = r.WithContext(ctx)
route := mux.CurrentRoute(r)
path, _ := route.GetPathTemplate()

Expand All @@ -50,9 +46,9 @@ func (a *Analytics) Wrap(next http.Handler) http.Handler {
}

if _, ok := telemetry.EnabledPaths()[path]; ok {
userEmail, err := auth.GetEmailFromJwt(r.Context())
if err == nil {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, userEmail, true, false)
claims, ok := authtypes.GetClaimsFromContext(r.Context())
if ok {
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_PATH, data, claims.Email, true, false)
}
}

Expand Down Expand Up @@ -138,8 +134,8 @@ func (a *Analytics) extractQueryRangeData(path string, r *http.Request) (map[str
data["queryType"] = queryInfoResult.QueryType
data["panelType"] = queryInfoResult.PanelType

userEmail, err := auth.GetEmailFromJwt(r.Context())
if err == nil {
claims, ok := authtypes.GetClaimsFromContext(r.Context())
if ok {
// switch case to set data["screen"] based on the referrer
switch {
case dashboardMatched:
Expand All @@ -154,7 +150,7 @@ func (a *Analytics) extractQueryRangeData(path string, r *http.Request) (map[str
data["screen"] = "unknown"
return data, true
}
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_QUERY_RANGE_API, data, userEmail, true, false)
telemetry.GetInstance().SendEvent(telemetry.TELEMETRY_EVENT_QUERY_RANGE_API, data, claims.Email, true, false)
}
}
return data, true
Expand Down
Loading
Loading