diff --git a/reddit/plugin_web.go b/reddit/plugin_web.go index 3526512e6d..3552eb4fbf 100644 --- a/reddit/plugin_web.go +++ b/reddit/plugin_web.go @@ -12,6 +12,7 @@ import ( "github.com/jonas747/discordgo" "github.com/jonas747/yagpdb/common" "github.com/jonas747/yagpdb/common/cplogs" + "github.com/jonas747/yagpdb/common/pubsub" "github.com/jonas747/yagpdb/reddit/models" "github.com/jonas747/yagpdb/web" "github.com/volatiletech/sqlboiler/boil" @@ -155,6 +156,10 @@ func HandleNew(w http.ResponseWriter, r *http.Request) interface{} { templateData.AddAlerts(web.SucessAlert("Sucessfully added subreddit feed for /r/" + watchItem.Subreddit)) go cplogs.RetryAddEntry(web.NewLogEntryFromContext(r.Context(), panelLogKeyAddedFeed, &cplogs.Param{Type: cplogs.ParamTypeString, Value: watchItem.Subreddit})) + go pubsub.Publish("reddit_clear_subreddit_cache", -1, PubSubSubredditEventData{ + Subreddit: strings.ToLower(strings.TrimSpace(newElem.Subreddit)), + Slow: newElem.Slow, + }) return templateData } @@ -193,6 +198,10 @@ func HandleModify(w http.ResponseWriter, r *http.Request) interface{} { templateData.AddAlerts(web.SucessAlert("Sucessfully updated reddit feed! :D")) go cplogs.RetryAddEntry(web.NewLogEntryFromContext(r.Context(), panelLogKeyUpdatedFeed, &cplogs.Param{Type: cplogs.ParamTypeString, Value: item.Subreddit})) + go pubsub.Publish("reddit_clear_subreddit_cache", -1, PubSubSubredditEventData{ + Subreddit: strings.ToLower(strings.TrimSpace(item.Subreddit)), + Slow: item.Slow, + }) return templateData } @@ -233,6 +242,10 @@ func HandleRemove(w http.ResponseWriter, r *http.Request) interface{} { templateData["RedditConfig"] = currentConfig go cplogs.RetryAddEntry(web.NewLogEntryFromContext(r.Context(), panelLogKeyRemovedFeed, &cplogs.Param{Type: cplogs.ParamTypeString, Value: item.Subreddit})) + go pubsub.Publish("reddit_clear_subreddit_cache", -1, PubSubSubredditEventData{ + Subreddit: strings.ToLower(strings.TrimSpace(item.Subreddit)), + Slow: item.Slow, + }) return templateData } diff --git a/reddit/reddit.go b/reddit/reddit.go index 802f314062..cb511340c8 100644 --- a/reddit/reddit.go +++ b/reddit/reddit.go @@ -11,6 +11,7 @@ import ( "github.com/jonas747/go-reddit" "github.com/jonas747/yagpdb/common" "github.com/jonas747/yagpdb/common/mqueue" + "github.com/jonas747/yagpdb/common/pubsub" "github.com/jonas747/yagpdb/premium" "github.com/jonas747/yagpdb/reddit/models" ) @@ -79,6 +80,20 @@ func RegisterPlugin() { common.RegisterPlugin(plugin) mqueue.RegisterSource("reddit", plugin) + + pubsub.AddHandler("reddit_clear_subreddit_cache", func(evt *pubsub.Event) { + dataCast := evt.Data.(*PubSubSubredditEventData) + if dataCast.Slow { + configCache.Delete(KeySlowFeeds(strings.ToLower(dataCast.Subreddit))) + } else { + configCache.Delete(KeyFastFeeds(strings.ToLower(dataCast.Subreddit))) + } + }, PubSubSubredditEventData{}) +} + +type PubSubSubredditEventData struct { + Subreddit string `json:"subreddit"` + Slow bool `json:"slow"` } const ( diff --git a/reddit/redditbot.go b/reddit/redditbot.go index 98e78a0722..7169eb81e2 100644 --- a/reddit/redditbot.go +++ b/reddit/redditbot.go @@ -95,6 +95,11 @@ func (p *Plugin) runBot() { feedLock.Unlock() } +type KeySlowFeeds string +type KeyFastFeeds string + +var configCache sync.Map + type PostHandlerImpl struct { Slow bool ratelimiter *Ratelimiter @@ -126,27 +131,57 @@ func (p *PostHandlerImpl) HandleRedditPosts(links []*reddit.Link) { } } -func (p *PostHandlerImpl) handlePost(post *reddit.Link, filterGuild int64) error { +func (p *PostHandlerImpl) getConfigs(subreddit string) ([]*models.RedditFeed, error) { + var key interface{} + key = KeySlowFeeds(subreddit) + if !p.Slow { + key = KeyFastFeeds(subreddit) + } - // createdSince := time.Since(time.Unix(int64(post.CreatedUtc), 0)) - // logger.Printf("[%5.1fs] /r/%-15s: %s, %s", createdSince.Seconds(), post.Subreddit, post.Title, post.ID) + v, ok := configCache.Load(key) + if ok { + return v.(models.RedditFeedSlice), nil + } qms := []qm.QueryMod{ - models.RedditFeedWhere.Subreddit.EQ(strings.ToLower(post.Subreddit)), + models.RedditFeedWhere.Subreddit.EQ(strings.ToLower(subreddit)), models.RedditFeedWhere.Slow.EQ(p.Slow), models.RedditFeedWhere.Disabled.EQ(false), } - if filterGuild > 0 { - qms = append(qms, models.RedditFeedWhere.GuildID.EQ(filterGuild)) + config, err := models.RedditFeeds(qms...).AllG(context.Background()) + if err != nil { + logger.WithError(err).Error("failed retrieving reddit feeds for subreddit") + return nil, err } - config, err := models.RedditFeeds(qms...).AllG(context.Background()) + configCache.Store(key, config) + + return config, nil +} + +func (p *PostHandlerImpl) handlePost(post *reddit.Link, filterGuild int64) error { + + // createdSince := time.Since(time.Unix(int64(post.CreatedUtc), 0)) + // logger.Printf("[%5.1fs] /r/%-15s: %s, %s", createdSince.Seconds(), post.Subreddit, post.Title, post.ID) + + config, err := p.getConfigs(strings.ToLower(post.Subreddit)) if err != nil { logger.WithError(err).Error("failed retrieving reddit feeds for subreddit") return err } + if filterGuild > 0 { + filtered := make([]*models.RedditFeed, 0) + for _, v := range config { + if v.GuildID == filterGuild { + filtered = append(filtered, v) + } + } + + config = filtered + } + // Get the configs that listens to this subreddit, if any filteredItems := p.FilterFeeds(config, post)