diff --git a/pkg/alertmanager/alertmanagerstore/sqlalertmanagerstore/state.go b/pkg/alertmanager/alertmanagerstore/sqlalertmanagerstore/state.go index e7407fc220..008db2cda3 100644 --- a/pkg/alertmanager/alertmanagerstore/sqlalertmanagerstore/state.go +++ b/pkg/alertmanager/alertmanagerstore/sqlalertmanagerstore/state.go @@ -3,8 +3,6 @@ package sqlalertmanagerstore import ( "context" "database/sql" - "encoding/base64" - "time" "go.signoz.io/signoz/pkg/errors" "go.signoz.io/signoz/pkg/sqlstore" @@ -20,76 +18,52 @@ func NewStateStore(sqlstore sqlstore.SQLStore) alertmanagertypes.StateStore { } // Get implements alertmanagertypes.StateStore. -func (store *state) Get(ctx context.Context, orgID string, stateName alertmanagertypes.StateName) (string, error) { - storeableConfig := new(alertmanagertypes.StoreableConfig) +func (store *state) Get(ctx context.Context, orgID string) (*alertmanagertypes.StoreableState, error) { + storeableState := new(alertmanagertypes.StoreableState) err := store. sqlstore. BunDB(). NewSelect(). - Model(storeableConfig). + Model(storeableState). Where("org_id = ?", orgID). Scan(ctx) if err != nil { if err == sql.ErrNoRows { - return "", errors.Newf(errors.TypeNotFound, alertmanagertypes.ErrCodeAlertmanagerStateNotFound, "cannot find alertmanager state for org %s", orgID) + return nil, errors.Newf(errors.TypeNotFound, alertmanagertypes.ErrCodeAlertmanagerStateNotFound, "cannot find alertmanager state for org %s", orgID) } - return "", err + return nil, err } - if stateName == alertmanagertypes.SilenceStateName { - decodedState, err := base64.RawStdEncoding.DecodeString(storeableConfig.SilencesState) - if err != nil { - return "", err - } - - return string(decodedState), nil - } - - if stateName == alertmanagertypes.NFLogStateName { - decodedState, err := base64.RawStdEncoding.DecodeString(storeableConfig.NFLogState) - if err != nil { - return "", err - } - - return string(decodedState), nil - } - - // This should never happen - return "", errors.Newf(errors.TypeNotFound, alertmanagertypes.ErrCodeAlertmanagerStateNameInvalid, "cannot find state with name %s for org %s", stateName.String(), orgID) + return storeableState, nil } // Set implements alertmanagertypes.StateStore. -func (store *state) Set(ctx context.Context, orgID string, stateName alertmanagertypes.StateName, state alertmanagertypes.State) (int64, error) { - storeableConfig := new(alertmanagertypes.StoreableConfig) - - marshalledState, err := state.MarshalBinary() +func (store *state) Set(ctx context.Context, orgID string, storeableState *alertmanagertypes.StoreableState) error { + tx, err := store.sqlstore.BunDB().BeginTx(ctx, nil) if err != nil { - return 0, err + return err } - encodedState := base64.StdEncoding.EncodeToString(marshalledState) - - q := store. - sqlstore. - BunDB(). - NewUpdate(). - Model(storeableConfig). - Set("updated_at = ?", time.Now()). - Where("org_id = ?", orgID) - if stateName == alertmanagertypes.SilenceStateName { - q.Set("silences_state = ?", encodedState) - } + defer tx.Rollback() //nolint:errcheck - if stateName == alertmanagertypes.NFLogStateName { - q.Set("nflog_state = ?", encodedState) + _, err = tx. + NewInsert(). + Model(storeableState). + On("CONFLICT (org_id) DO UPDATE"). + Set("silences = EXCLUDED.silences"). + Set("nflog = EXCLUDED.nflog"). + Set("updated_at = EXCLUDED.updated_at"). + Where("org_id = ?", orgID). + Exec(ctx) + if err != nil { + return err } - _, err = q.Exec(ctx) - if err != nil { - return 0, err + if err := tx.Commit(); err != nil { + return err } - return int64(len(marshalledState)), nil + return nil } diff --git a/pkg/alertmanager/server/server.go b/pkg/alertmanager/server/server.go index 7e35b400ea..06252beec8 100644 --- a/pkg/alertmanager/server/server.go +++ b/pkg/alertmanager/server/server.go @@ -77,20 +77,18 @@ func New(ctx context.Context, logger *slog.Logger, registry prometheus.Registere server.marker = alertmanagertypes.NewMarker(server.registry) // get silences for initial state - silencesstate, err := server.stateStore.Get(ctx, server.orgID, alertmanagertypes.SilenceStateName) + state, err := server.stateStore.Get(ctx, server.orgID) if err != nil && !errors.Ast(err, errors.TypeNotFound) { return nil, err } - // get nflog for initial state - nflogstate, err := server.stateStore.Get(ctx, server.orgID, alertmanagertypes.NFLogStateName) - if err != nil && !errors.Ast(err, errors.TypeNotFound) { - return nil, err + silencesSnapshot := "" + if state != nil { + silencesSnapshot = state.Silences } - // Initialize silences server.silences, err = silence.New(silence.Options{ - SnapshotReader: strings.NewReader(silencesstate), + SnapshotReader: strings.NewReader(silencesSnapshot), Retention: srvConfig.Silences.Retention, Limits: silence.Limits{ MaxSilences: func() int { return srvConfig.Silences.Max }, @@ -103,9 +101,14 @@ func New(ctx context.Context, logger *slog.Logger, registry prometheus.Registere return nil, err } + nflogSnapshot := "" + if state != nil { + nflogSnapshot = state.NFLog + } + // Initialize notification log server.nflog, err = nflog.New(nflog.Options{ - SnapshotReader: strings.NewReader(nflogstate), + SnapshotReader: strings.NewReader(nflogSnapshot), Retention: server.srvConfig.NFLog.Retention, Metrics: server.registry, Logger: server.logger, @@ -125,7 +128,21 @@ func New(ctx context.Context, logger *slog.Logger, registry prometheus.Registere // Don't return here - we need to snapshot our state first. } - return server.stateStore.Set(ctx, server.orgID, alertmanagertypes.SilenceStateName, server.silences) + state, err := server.stateStore.Get(ctx, server.orgID) + if err != nil && !errors.Ast(err, errors.TypeNotFound) { + return 0, err + } + + if state == nil { + state = alertmanagertypes.NewStoreableState(server.orgID) + } + + c, err := state.Set(alertmanagertypes.SilenceStateName, server.silences) + if err != nil { + return 0, err + } + + return c, server.stateStore.Set(ctx, server.orgID, state) }) }() @@ -140,7 +157,21 @@ func New(ctx context.Context, logger *slog.Logger, registry prometheus.Registere // Don't return without saving the current state. } - return server.stateStore.Set(ctx, server.orgID, alertmanagertypes.NFLogStateName, server.nflog) + state, err := server.stateStore.Get(ctx, server.orgID) + if err != nil && !errors.Ast(err, errors.TypeNotFound) { + return 0, err + } + + if state == nil { + state = alertmanagertypes.NewStoreableState(server.orgID) + } + + c, err := state.Set(alertmanagertypes.NFLogStateName, server.nflog) + if err != nil { + return 0, err + } + + return c, server.stateStore.Set(ctx, server.orgID, state) }) }() diff --git a/pkg/types/alertmanagertypes/alertmanagertypestest/state.go b/pkg/types/alertmanagertypes/alertmanagertypestest/state.go index f52f59e364..4aa4b5bbaf 100644 --- a/pkg/types/alertmanagertypes/alertmanagertypestest/state.go +++ b/pkg/types/alertmanagertypes/alertmanagertypestest/state.go @@ -2,54 +2,36 @@ package alertmanagertypestest import ( "context" - "encoding/base64" "sync" "go.signoz.io/signoz/pkg/errors" "go.signoz.io/signoz/pkg/types/alertmanagertypes" ) +var _ alertmanagertypes.StateStore = (*StateStore)(nil) + type StateStore struct { - states map[string]map[string]string + states map[string]*alertmanagertypes.StoreableState mtx sync.RWMutex } func NewStateStore() *StateStore { return &StateStore{ - states: make(map[string]map[string]string), + states: make(map[string]*alertmanagertypes.StoreableState), } } -func (s *StateStore) Set(ctx context.Context, orgID string, stateName alertmanagertypes.StateName, state alertmanagertypes.State) (int64, error) { - if _, ok := s.states[orgID]; !ok { - s.states[orgID] = make(map[string]string) - } - - bytes, err := state.MarshalBinary() - if err != nil { - return 0, err - } - +func (s *StateStore) Set(ctx context.Context, orgID string, storeableState *alertmanagertypes.StoreableState) error { s.mtx.Lock() - s.states[orgID][stateName.String()] = base64.StdEncoding.EncodeToString(bytes) + s.states[orgID] = storeableState s.mtx.Unlock() - return int64(len(bytes)), nil + return nil } -func (s *StateStore) Get(ctx context.Context, orgID string, stateName alertmanagertypes.StateName) (string, error) { +func (s *StateStore) Get(ctx context.Context, orgID string) (*alertmanagertypes.StoreableState, error) { if _, ok := s.states[orgID]; !ok { - return "", errors.Newf(errors.TypeNotFound, alertmanagertypes.ErrCodeAlertmanagerStateNotFound, "state %q for orgID %q not found", stateName.String(), orgID) - } - - state, ok := s.states[orgID][stateName.String()] - if !ok { - return "", errors.Newf(errors.TypeNotFound, alertmanagertypes.ErrCodeAlertmanagerStateNotFound, "state %q for orgID %q not found", stateName.String(), orgID) - } - - bytes, err := base64.StdEncoding.DecodeString(state) - if err != nil { - return "", err + return nil, errors.Newf(errors.TypeNotFound, alertmanagertypes.ErrCodeAlertmanagerStateNotFound, "state for orgID %q not found", orgID) } - return string(bytes), nil + return s.states[orgID], nil } diff --git a/pkg/types/alertmanagertypes/config.go b/pkg/types/alertmanagertypes/config.go index b0b5e0ed59..58c5750a7a 100644 --- a/pkg/types/alertmanagertypes/config.go +++ b/pkg/types/alertmanagertypes/config.go @@ -39,13 +39,11 @@ type RouteConfig struct { type StoreableConfig struct { bun.BaseModel `bun:"table:alertmanager_config"` - ID uint64 `bun:"id"` - Config string `bun:"config"` - SilencesState string `bun:"silences_state,nullzero"` - NFLogState string `bun:"nflog_state,nullzero"` - CreatedAt time.Time `bun:"created_at"` - UpdatedAt time.Time `bun:"updated_at"` - OrgID string `bun:"org_id"` + ID uint64 `bun:"id,pk,autoincrement"` + Config string `bun:"config"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + OrgID string `bun:"org_id"` } // Config is the type for the entire alertmanager configuration @@ -68,12 +66,10 @@ func NewConfig(c *config.Config, orgID string) *Config { return &Config{ alertmanagerConfig: c, storeableConfig: &StoreableConfig{ - Config: string(newRawFromConfig(c)), - SilencesState: "", - NFLogState: "", - CreatedAt: time.Now(), - UpdatedAt: time.Now(), - OrgID: orgID, + Config: string(newRawFromConfig(c)), + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + OrgID: orgID, }, channels: channels, } diff --git a/pkg/types/alertmanagertypes/state.go b/pkg/types/alertmanagertypes/state.go index 657b1c34ff..25ac871e10 100644 --- a/pkg/types/alertmanagertypes/state.go +++ b/pkg/types/alertmanagertypes/state.go @@ -2,8 +2,11 @@ package alertmanagertypes import ( "context" + "encoding/base64" + "time" "github.com/prometheus/alertmanager/cluster" + "github.com/uptrace/bun" "go.signoz.io/signoz/pkg/errors" ) @@ -19,10 +22,69 @@ var ( ) var ( - ErrCodeAlertmanagerStateNotFound = errors.MustNewCode("alertmanager_state_not_found") - ErrCodeAlertmanagerStateNameInvalid = errors.MustNewCode("alertmanager_state_name_invalid") + ErrCodeAlertmanagerStateNotFound = errors.MustNewCode("alertmanager_state_not_found") ) +type StoreableState struct { + bun.BaseModel `bun:"table:alertmanager_state"` + + ID uint64 `bun:"id,pk,autoincrement"` + Silences string `bun:"silences,nullzero"` + NFLog string `bun:"nflog,nullzero"` + CreatedAt time.Time `bun:"created_at"` + UpdatedAt time.Time `bun:"updated_at"` + OrgID string `bun:"org_id"` +} + +func NewStoreableState(orgID string) *StoreableState { + return &StoreableState{ + OrgID: orgID, + CreatedAt: time.Now(), + UpdatedAt: time.Now(), + } +} + +func (s *StoreableState) Set(stateName StateName, state State) (int64, error) { + marshalledState, err := state.MarshalBinary() + if err != nil { + return 0, err + } + encodedState := base64.StdEncoding.EncodeToString(marshalledState) + + switch stateName { + case SilenceStateName: + s.Silences = encodedState + case NFLogStateName: + s.NFLog = encodedState + } + + s.UpdatedAt = time.Now() + + return int64(len(marshalledState)), nil +} + +func (s *StoreableState) Get(stateName StateName) (string, error) { + base64encodedState := "" + + switch stateName { + case SilenceStateName: + base64encodedState = s.Silences + case NFLogStateName: + base64encodedState = s.NFLog + } + + if base64encodedState == "" { + return "", errors.New(errors.TypeNotFound, ErrCodeAlertmanagerStateNotFound, "state not found") + } + + decodedState, err := base64.StdEncoding.DecodeString(base64encodedState) + if err != nil { + return "", err + } + + return string(decodedState), nil +} + type StateName struct { name string } @@ -36,9 +98,9 @@ type StateStore interface { // The return type matches the return of `silence.Maintenance` or `nflog.Maintenance`. // See https://github.com/prometheus/alertmanager/blob/3b06b97af4d146e141af92885a185891eb79a5b0/silence/silence.go#L217 // and https://github.com/prometheus/alertmanager/blob/3b06b97af4d146e141af92885a185891eb79a5b0/nflog/nflog.go#L94 - Set(context.Context, string, StateName, State) (int64, error) + Set(context.Context, string, *StoreableState) error // Gets the silence state or the notification log state as a string from the store. This is used as a snapshot to load the // initial state of silences or notification log when starting the alertmanager. - Get(context.Context, string, StateName) (string, error) + Get(context.Context, string) (*StoreableState, error) }