Skip to content

Commit

Permalink
disable excess items crossing premium limits on premium removal from …
Browse files Browse the repository at this point in the history
…a server (botlabs-gg#1446)

* disable excess items crossing premium limits on premium disable

* make it errorf

---------

Co-authored-by: Ashish Jhanwar <ashishjh-bst@users.noreply.github.com>
  • Loading branch information
ashishjh-bst and ashishjh-bst committed Feb 17, 2023
1 parent 8a7d047 commit 2d78392
Show file tree
Hide file tree
Showing 9 changed files with 97 additions and 23 deletions.
16 changes: 16 additions & 0 deletions customcommands/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,22 @@ const (
CCMessageExecLimitPremium = 5
)

func (p *Plugin) OnRemovedPremiumGuild(GuildID int64) error {
commands, err := models.CustomCommands(qm.Where("guild_id = ?", GuildID), qm.Offset(MaxCommands)).AllG(context.Background())
if err != nil {
return errors.WrapIf(err, "failed getting custom commands")
}

if len(commands) > 0 {
_, err = commands.UpdateAllG(context.Background(), models.M{"disabled": true})
if err != nil {
return errors.WrapIf(err, "failed disabling custom commands on premium removal")
}
}

return nil
}

var metricsExecutedCommands = promauto.NewCounterVec(prometheus.CounterOpts{
Name: "yagpdb_cc_triggered_total",
Help: "Number custom commands triggered",
Expand Down
24 changes: 17 additions & 7 deletions customcommands/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -277,11 +277,21 @@ func handleUpdateCommand(w http.ResponseWriter, r *http.Request) (web.TemplateDa
ctx := r.Context()
activeGuild, templateData := web.GetBaseCPContextData(ctx)

cmd := ctx.Value(common.ContextKeyParsedForm).(*CustomCommand)
cmdEdit := ctx.Value(common.ContextKeyParsedForm).(*CustomCommand)
cmdSaved, err := models.FindCustomCommandG(context.Background(), activeGuild.ID, int64(cmdEdit.ID))
if cmdSaved.Disabled == true && cmdEdit.ToDBModel().Disabled == false {
c, err := models.CustomCommands(qm.Where("guild_id = ?", activeGuild.ID)).CountG(ctx)
if err != nil {
return templateData, err
}
if int(c) >= MaxCommandsForContext(ctx) {
return templateData, web.NewPublicError(fmt.Sprintf("Max %d enabled custom commands allowed (or %d for premium servers)", MaxCommands, MaxCommandsPremium))
}
}

// ensure that the group specified is owned by this guild
if cmd.GroupID != 0 {
c, err := models.CustomCommandGroups(qm.Where("guild_id = ? AND id = ?", activeGuild.ID, cmd.GroupID)).CountG(ctx)
if cmdEdit.GroupID != 0 {
c, err := models.CustomCommandGroups(qm.Where("guild_id = ? AND id = ?", activeGuild.ID, cmdEdit.GroupID)).CountG(ctx)
if err != nil {
return templateData, err
}
Expand All @@ -291,13 +301,13 @@ func handleUpdateCommand(w http.ResponseWriter, r *http.Request) (web.TemplateDa
}
}

dbModel := cmd.ToDBModel()
dbModel := cmdEdit.ToDBModel()

templateData["CurrentGroupID"] = dbModel.GroupID.Int64

dbModel.GuildID = activeGuild.ID
dbModel.LocalID = cmd.ID
dbModel.TriggerType = int(triggerTypeFromForm(cmd.TriggerTypeForm))
dbModel.LocalID = cmdEdit.ID
dbModel.TriggerType = int(triggerTypeFromForm(cmdEdit.TriggerTypeForm))

// check low interval limits
if dbModel.TriggerType == int(CommandTriggerInterval) && dbModel.TimeTriggerInterval <= 10 {
Expand All @@ -311,7 +321,7 @@ func handleUpdateCommand(w http.ResponseWriter, r *http.Request) (web.TemplateDa
}
}

_, err := dbModel.UpdateG(ctx, boil.Blacklist("last_run", "next_run", "local_id", "guild_id", "last_error", "last_error_time", "run_count"))
_, err = dbModel.UpdateG(ctx, boil.Blacklist("last_run", "next_run", "local_id", "guild_id", "last_error", "last_error_time", "run_count"))
if err != nil {
return templateData, nil
}
Expand Down
19 changes: 19 additions & 0 deletions reddit/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/botlabs-gg/yagpdb/v2/lib/dcmd"
"github.com/botlabs-gg/yagpdb/v2/reddit/models"
"github.com/botlabs-gg/yagpdb/v2/stdcommands/util"
"github.com/volatiletech/sqlboiler/queries/qm"
)

var _ bot.RemoveGuildHandler = (*Plugin)(nil)
Expand All @@ -26,6 +27,24 @@ func (p *Plugin) RemoveGuild(g int64) error {
return nil
}

func (p *Plugin) OnRemovedPremiumGuild(guildID int64) error {
logger.WithField("guild_id", guildID).Infof("Removed Excess Reddit Feeds")
feeds, err := models.RedditFeeds(qm.Where("guild_id = ? and disabled = ?", guildID, false), qm.Offset(GuildMaxFeedsNormal)).AllG(context.Background())

if err != nil {
return errors.WrapIf(err, "failed getting reddit feeds")
}

if len(feeds) > 0 {
_, err = feeds.UpdateAllG(context.Background(), models.M{"disabled": true})
if err != nil {
return errors.WrapIf(err, "failed disabling reddit feeds on premium removal")
}
}

return nil
}

func (p *Plugin) AddCommands() {
commands.AddRootCommands(p, &commands.YAGCommand{
CmdCategory: commands.CategoryDebug,
Expand Down
1 change: 0 additions & 1 deletion reddit/migrate_to_psql.go
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,6 @@ func FindLegacyWatchItem(source []*LegacySubredditWatchItem, id int) *LegacySubr
for _, c := range source {
if c.ID == id {
return c
break
}
}
return nil
Expand Down
4 changes: 2 additions & 2 deletions reddit/plugin_web.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ const (
type CreateForm struct {
Subreddit string `schema:"subreddit" valid:",1,100"`
Slow bool `schema:"slow"`
Channel int64 `schema:"channel" valid:"channel,true`
Channel int64 `schema:"channel" valid:"channel,true"`
ID int64 `schema:"id"`
UseEmbeds bool `schema:"use_embeds"`
NSFWMode int `schema:"nsfw_filter"`
MinUpvotes int `schema:"min_upvotes" valid:"0,"`
}

type UpdateForm struct {
Channel int64 `schema:"channel" valid:"channel,true`
Channel int64 `schema:"channel" valid:"channel,true"`
ID int64 `schema:"id"`
UseEmbeds bool `schema:"use_embeds"`
NSFWMode int `schema:"nsfw_filter"`
Expand Down
10 changes: 10 additions & 0 deletions twitter/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,13 @@ func (p *Plugin) DisableFeed(elem *mqueue.QueuedElement, err error) {
logger.WithError(err).WithField("feed_id", feedID).Error("failed removing feed")
}
}

func (p *Plugin) OnRemovedPremiumGuild(guildID int64) error {
logger.WithField("guild_id", guildID).Infof("Removed Excess Twitter Feeds")
_, err := models.TwitterFeeds(models.TwitterFeedWhere.GuildID.EQ(int64(guildID))).UpdateAllG(context.Background(), models.M{"enabled": false})
if err != nil {
logger.WithError(err).WithField("guild_id", guildID).Error("failed disabling feed for missing premium")
return err
}
return nil
}
6 changes: 0 additions & 6 deletions twitter/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ import (
//go:embed assets/twitter.html
var PageHTML string

type CtxKey int

const (
CurrentConfig CtxKey = iota
)

type Form struct {
TwitterUser string `valid:",1,256"`
DiscordChannel int64 `valid:"channel,false"`
Expand Down
21 changes: 21 additions & 0 deletions youtube/bot.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,24 @@ func (p *Plugin) Status() (string, string) {

return "Youtube", fmt.Sprintf("%d/%d", unique, numChannels)
}

func (p *Plugin) OnRemovedPremiumGuild(guildID int64) error {
logger.WithField("guild_id", guildID).Infof("Removed Excess Youtube Feeds")
feeds := make([]ChannelSubscription, 0)
err := common.GORM.Model(&ChannelSubscription{}).Where(`guild_id = ? and enabled = ?`, guildID, common.BoolToPointer(true)).Offset(GuildMaxFeeds).Order(
"id desc",
).Find(&feeds).Error
if err != nil {
logger.WithError(err).Errorf("failed getting feed ids for guild_id %d", guildID)
return err
}

if len(feeds) > 0 {
err = common.GORM.Model(&feeds).Update(ChannelSubscription{Enabled: common.BoolToPointer(false)}).Error
if err != nil {
logger.WithError(err).Errorf("failed getting feed ids for guild_id %d", guildID)
return err
}
}
return nil
}
19 changes: 12 additions & 7 deletions youtube/web.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,6 @@ import (
"goji.io/pat"
)

type CtxKey int

const (
CurrentConfig CtxKey = iota
)

//go:embed assets/youtube.html
var PageHTML string

Expand Down Expand Up @@ -163,7 +157,6 @@ func (p *Plugin) HandleNew(w http.ResponseWriter, r *http.Request) (web.Template
// limit it to max 25 feeds
var count int
common.GORM.Model(&ChannelSubscription{}).Where("guild_id = ?", activeGuild.ID).Count(&count)

if count >= MaxFeedsForContext(ctx) {
return templateData.AddAlerts(web.ErrorAlert(fmt.Sprintf("Max %d youtube feeds allowed (%d for premium servers)", GuildMaxFeeds, GuildMaxFeedsPremium))), nil
}
Expand Down Expand Up @@ -237,6 +230,18 @@ func (p *Plugin) HandleEdit(w http.ResponseWriter, r *http.Request) (templateDat
sub.PublishShorts = &data.PublishShorts
sub.ChannelID = discordgo.StrID(data.DiscordChannel)
sub.Enabled = &data.Enabled
count := 0
common.GORM.Model(&ChannelSubscription{}).Where("guild_id = ? and enabled = ?", sub.GuildID, common.BoolToPointer(true)).Count(&count)
if count >= MaxFeedsForContext(ctx) {
var currFeed ChannelSubscription
err := common.GORM.Model(&ChannelSubscription{}).Where("id = ?", sub.ID).First(&currFeed)
if err != nil {
logger.WithError(err.Error).Errorf("Failed getting feed %d", sub.ID)
}
if !*currFeed.Enabled && *sub.Enabled {
return templateData.AddAlerts(web.ErrorAlert(fmt.Sprintf("Max %d enabled youtube feeds allowed (%d for premium servers)", GuildMaxFeeds, GuildMaxFeedsPremium))), nil
}
}

err = common.GORM.Save(sub).Error
if err == nil {
Expand Down

0 comments on commit 2d78392

Please sign in to comment.