Skip to content

Commit

Permalink
[Heartbeat]: limit parallelization of tasks by jobtype (#27160)
Browse files Browse the repository at this point in the history
* [Heartbeat]: limit parallelization of tasks by jobtype

* handle release and add tests

* fix linting

* add guard against nil in tests

* address review comments

* add test when limit is not specified
  • Loading branch information
vigneshshanmugam authored Aug 11, 2021
1 parent 07b546d commit 70402c6
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 42 deletions.
3 changes: 2 additions & 1 deletion heartbeat/beater/heartbeat.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,9 @@ func New(b *beat.Beat, rawConfig *common.Config) (beat.Beater, error) {
if err != nil {
return nil, err
}
jobConfig := parsedConfig.Jobs

scheduler := scheduler.NewWithLocation(limit, hbregistry.SchedulerRegistry, location)
scheduler := scheduler.NewWithLocation(limit, hbregistry.SchedulerRegistry, location, jobConfig)

bt := &Heartbeat{
done: make(chan struct{}),
Expand Down
5 changes: 5 additions & 0 deletions heartbeat/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ type Config struct {
Scheduler Scheduler `config:"scheduler"`
Autodiscover *autodiscover.Config `config:"autodiscover"`
SyntheticSuites []*common.Config `config:"synthetic_suites"`
Jobs map[string]JobLimit `config:"jobs"`
}

// JobLimit holds the concurrency limit configured for one job type. It is the
// value type of the Jobs map in Config, which is populated from the `jobs`
// config block.
type JobLimit struct {
// Limit is read from the `limit` key and validated to be >= 0.
Limit int64 `config:"limit" validate:"min=0"`
}

// Scheduler defines the syntax of a heartbeat.yml scheduler block.
Expand Down
2 changes: 1 addition & 1 deletion heartbeat/monitors/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ func (t *configuredJob) Start() {
}

tf := t.makeSchedulerTaskFunc()
t.cancelFn, err = t.monitor.scheduler.Add(t.config.Schedule, t.monitor.stdFields.ID, tf)
t.cancelFn, err = t.monitor.scheduler.Add(t.config.Schedule, t.monitor.stdFields.ID, tf, t.config.Type)
if err != nil {
logp.Err("could not start monitor: %v", err)
}
Expand Down
70 changes: 45 additions & 25 deletions heartbeat/scheduler/scheduler.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"golang.org/x/sync/semaphore"

"github.com/elastic/beats/v7/heartbeat/config"
"github.com/elastic/beats/v7/heartbeat/scheduler/timerqueue"
"github.com/elastic/beats/v7/libbeat/common/atomic"
"github.com/elastic/beats/v7/libbeat/logp"
Expand All @@ -46,14 +47,15 @@ var ErrInvalidTransition = fmt.Errorf("invalid state transition")

// Scheduler represents our async timer based scheduler.
type Scheduler struct {
limit int64
limitSem *semaphore.Weighted
state atomic.Int
location *time.Location
timerQueue *timerqueue.TimerQueue
ctx context.Context
cancelCtx context.CancelFunc
stats schedulerStats
limit int64
limitSem *semaphore.Weighted
state atomic.Int
location *time.Location
timerQueue *timerqueue.TimerQueue
ctx context.Context
cancelCtx context.CancelFunc
stats schedulerStats
jobLimitSem map[string]*semaphore.Weighted
}

type schedulerStats struct {
Expand All @@ -77,13 +79,23 @@ type Schedule interface {
RunOnInit() bool
}

// getJobLimitSem turns the per-job-type limits into a map of weighted
// semaphores keyed by job type. Only strictly positive limits produce a
// semaphore; job types with any other limit are left out of the result.
// The returned map is never nil, even when jobLimitByType is nil or empty.
func getJobLimitSem(jobLimitByType map[string]config.JobLimit) map[string]*semaphore.Weighted {
	sems := make(map[string]*semaphore.Weighted, len(jobLimitByType))
	for typ, cfg := range jobLimitByType {
		if cfg.Limit <= 0 {
			// No positive limit configured for this type: no semaphore.
			continue
		}
		sems[typ] = semaphore.NewWeighted(cfg.Limit)
	}
	return sems
}

// New creates a new Scheduler
func New(limit int64, registry *monitoring.Registry) *Scheduler {
return NewWithLocation(limit, registry, time.Local)
return NewWithLocation(limit, registry, time.Local, nil)
}

// NewWithLocation creates a new Scheduler using the given runAt zone.
func NewWithLocation(limit int64, registry *monitoring.Registry, location *time.Location) *Scheduler {
func NewWithLocation(limit int64, registry *monitoring.Registry, location *time.Location, jobLimitByType map[string]config.JobLimit) *Scheduler {
ctx, cancelCtx := context.WithCancel(context.Background())

if limit < 1 {
Expand All @@ -96,14 +108,14 @@ func NewWithLocation(limit int64, registry *monitoring.Registry, location *time.
waitingTasksGauge := monitoring.NewUint(registry, "tasks.waiting")

sched := &Scheduler{
limit: limit,
location: location,
state: atomic.MakeInt(statePreRunning),
ctx: ctx,
cancelCtx: cancelCtx,
limitSem: semaphore.NewWeighted(limit),

timerQueue: timerqueue.NewTimerQueue(ctx),
limit: limit,
location: location,
state: atomic.MakeInt(statePreRunning),
ctx: ctx,
cancelCtx: cancelCtx,
limitSem: semaphore.NewWeighted(limit),
jobLimitSem: getJobLimitSem(jobLimitByType),
timerQueue: timerqueue.NewTimerQueue(ctx),

stats: schedulerStats{
activeJobs: activeJobsGauge,
Expand Down Expand Up @@ -174,7 +186,7 @@ var ErrAlreadyStopped = errors.New("attempted to add job to already stopped sche

// Add adds the given TaskFunc to the current scheduler. Will return an error if the scheduler
// is done.
func (s *Scheduler) Add(sched Schedule, id string, entrypoint TaskFunc) (removeFn context.CancelFunc, err error) {
func (s *Scheduler) Add(sched Schedule, id string, entrypoint TaskFunc, jobType string) (removeFn context.CancelFunc, err error) {
if s.state.Load() == stateStopped {
return nil, ErrAlreadyStopped
}
Expand All @@ -195,7 +207,7 @@ func (s *Scheduler) Add(sched Schedule, id string, entrypoint TaskFunc) (removeF
default:
}
s.stats.activeJobs.Inc()
lastRanAt = s.runRecursiveJob(jobCtx, entrypoint)
lastRanAt = s.runRecursiveJob(jobCtx, entrypoint, jobType)
s.stats.activeJobs.Dec()
s.runOnce(sched.Next(lastRanAt), taskFn)
debugf("Job '%v' returned at %v", id, time.Now())
Expand Down Expand Up @@ -233,10 +245,14 @@ func (s *Scheduler) runOnce(runAt time.Time, taskFn timerqueue.TimerTaskFn) {
// runRecursiveJob runs the entry point for a job, blocking until all subtasks are completed.
// Subtasks are run in separate goroutines.
// returns the time execution began on its first task
func (s *Scheduler) runRecursiveJob(jobCtx context.Context, task TaskFunc) (startedAt time.Time) {
func (s *Scheduler) runRecursiveJob(jobCtx context.Context, task TaskFunc, jobType string) (startedAt time.Time) {
wg := &sync.WaitGroup{}
jobSem := s.jobLimitSem[jobType]
if jobSem != nil {
jobSem.Acquire(jobCtx, 1)
}
wg.Add(1)
startedAt = s.runRecursiveTask(jobCtx, task, wg)
startedAt = s.runRecursiveTask(jobCtx, task, wg, jobSem)
wg.Wait()
return startedAt
}
Expand All @@ -245,7 +261,7 @@ func (s *Scheduler) runRecursiveJob(jobCtx context.Context, task TaskFunc) (star
// Since task funcs can emit continuations recursively we need a function to execute
// recursively.
// The wait group passed into this function expects to already have its count incremented by one.
func (s *Scheduler) runRecursiveTask(jobCtx context.Context, task TaskFunc, wg *sync.WaitGroup) (startedAt time.Time) {
func (s *Scheduler) runRecursiveTask(jobCtx context.Context, task TaskFunc, wg *sync.WaitGroup, jobSem *semaphore.Weighted) (startedAt time.Time) {
defer wg.Done()

// The accounting for waiting/active tasks is done using atomics.
Expand Down Expand Up @@ -279,8 +295,12 @@ func (s *Scheduler) runRecursiveTask(jobCtx context.Context, task TaskFunc, wg *
wg.Add(len(continuations))
for _, cont := range continuations {
// Run continuations in parallel, note that these each will acquire their own slots
// We can discard the started at times for continuations as those are irrelevant
go s.runRecursiveTask(jobCtx, cont, wg)
// We can discard the started at times for continuations as those are
// irrelevant
go s.runRecursiveTask(jobCtx, cont, wg, jobSem)
}
if jobSem != nil && len(continuations) == 0 {
jobSem.Release(1)
}
}

Expand Down
109 changes: 94 additions & 15 deletions heartbeat/scheduler/scheduler_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package scheduler
import (
"context"
"fmt"
"math"
"sync"
"sync/atomic"
"testing"
Expand All @@ -28,6 +29,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/elastic/beats/v7/heartbeat/config"
batomic "github.com/elastic/beats/v7/libbeat/common/atomic"
"github.com/elastic/beats/v7/libbeat/monitoring"
)
Expand All @@ -50,7 +52,7 @@ func TestNew(t *testing.T) {
}

func TestNewWithLocation(t *testing.T) {
scheduler := NewWithLocation(123, monitoring.NewRegistry(), tarawaTime())
scheduler := NewWithLocation(123, monitoring.NewRegistry(), tarawaTime(), nil)
assert.Equal(t, int64(123), scheduler.limit)
assert.Equal(t, tarawaTime(), scheduler.location)
}
Expand Down Expand Up @@ -85,7 +87,7 @@ func testTaskTimes(limit uint32, fn TaskFunc) TaskFunc {
func TestScheduler_Start(t *testing.T) {
// We use tarawa runAt because it could expose some weird runAt math if by accident some code
// relied on the local TZ.
s := NewWithLocation(10, monitoring.NewRegistry(), tarawaTime())
s := NewWithLocation(10, monitoring.NewRegistry(), tarawaTime(), nil)
defer s.Stop()

executed := make(chan string)
Expand All @@ -98,7 +100,7 @@ func TestScheduler_Start(t *testing.T) {
return nil
}
return []TaskFunc{cont}
}))
}), "http")

removedEvents := uint32(1)
// This function will be removed after being invoked once
Expand All @@ -113,7 +115,7 @@ func TestScheduler_Start(t *testing.T) {
}
// Attempt to execute this twice to see if remove() had any effect
removeMtx.Lock()
remove, err := s.Add(testSchedule{}, "removed", testTaskTimes(removedEvents+1, testFn))
remove, err := s.Add(testSchedule{}, "removed", testTaskTimes(removedEvents+1, testFn), "http")
require.NoError(t, err)
require.NotNil(t, remove)
removeMtx.Unlock()
Expand All @@ -128,7 +130,7 @@ func TestScheduler_Start(t *testing.T) {
return nil
}
return []TaskFunc{cont}
}))
}), "http")

received := make([]string, 0)
// We test for a good number of events in this loop because we want to ensure that the remove() took effect
Expand Down Expand Up @@ -160,7 +162,7 @@ func TestScheduler_Start(t *testing.T) {
}

func TestScheduler_Stop(t *testing.T) {
s := NewWithLocation(10, monitoring.NewRegistry(), tarawaTime())
s := NewWithLocation(10, monitoring.NewRegistry(), tarawaTime(), nil)

executed := make(chan struct{})

Expand All @@ -170,7 +172,7 @@ func TestScheduler_Stop(t *testing.T) {
_, err := s.Add(testSchedule{}, "testPostStop", testTaskTimes(1, func(_ context.Context) []TaskFunc {
executed <- struct{}{}
return nil
}))
}), "http")

assert.Equal(t, ErrAlreadyStopped, err)
}
Expand Down Expand Up @@ -208,7 +210,7 @@ func TestScheduler_runRecursiveTask(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
limit := int64(100)
s := NewWithLocation(limit, monitoring.NewRegistry(), tarawaTime())
s := NewWithLocation(limit, monitoring.NewRegistry(), tarawaTime(), nil)

if testCase.overLimit {
s.limitSem.Acquire(context.Background(), limit)
Expand All @@ -224,7 +226,7 @@ func TestScheduler_runRecursiveTask(t *testing.T) {
}

beforeStart := time.Now()
startedAt := s.runRecursiveTask(testCase.jobCtx, tf, wg)
startedAt := s.runRecursiveTask(testCase.jobCtx, tf, wg, nil)

// This will panic in the case where we don't check s.limitSem.Acquire
// for an error value and released an unacquired resource in scheduler.go.
Expand All @@ -240,8 +242,87 @@ func TestScheduler_runRecursiveTask(t *testing.T) {
}
}

// makeTasks returns a TaskFunc that invokes callback and then, while num is at
// least 1, yields a single continuation built with num-1. Running the whole
// chain for a non-negative num therefore invokes callback num+1 times, one
// task after another.
func makeTasks(num int, callback func()) TaskFunc {
	return func(_ context.Context) []TaskFunc {
		callback()
		if num >= 1 {
			next := makeTasks(num-1, callback)
			return []TaskFunc{next}
		}
		// End of the chain: no further continuations.
		return nil
	}
}

// TestScheduler_runRecursiveJob starts numJobs jobs of type "http" at once,
// each a chain of tasks built by makeTasks, and records which job emitted each
// callback. The per-type limit is only configured when limit > 0.
func TestScheduler_runRecursiveJob(t *testing.T) {
	tests := []struct {
		name    string
		numJobs int
		limit   int64
		// expect receives the subtest's *testing.T so failures are attributed
		// to the subtest that produced the events.
		expect func(t *testing.T, events []int)
	}{
		{
			name:    "runs more than 1 with limit of 1",
			numJobs: 2,
			limit:   1,
			expect: func(t *testing.T, events []int) {
				// With a limit of 1 the jobs must not interleave: every event
				// in each half has to come from the same job.
				mid := len(events) / 2
				firstHalf := events[0:mid]
				lastHalf := events[mid:]
				for _, ele := range firstHalf {
					assert.Equal(t, firstHalf[0], ele)
				}
				for _, ele := range lastHalf {
					assert.Equal(t, lastHalf[0], ele)
				}
			},
		},
		{
			name:    "runs 50 interleaved without limit",
			numJobs: 50,
			limit:   math.MaxInt64,
			expect: func(t *testing.T, events []int) {
				require.GreaterOrEqual(t, len(events), 50)
			},
		},
		{
			name:    "runs 100 with limit not configured",
			numJobs: 100,
			limit:   0,
			expect: func(t *testing.T, events []int) {
				require.GreaterOrEqual(t, len(events), 100)
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			jobType := "http"
			jobConfigByType := map[string]config.JobLimit{}
			if tt.limit > 0 {
				jobConfigByType[jobType] = config.JobLimit{Limit: tt.limit}
			}
			s := NewWithLocation(math.MaxInt64, monitoring.NewRegistry(), tarawaTime(), jobConfigByType)

			// Jobs run concurrently when no effective limit applies, so the
			// shared event slice must be guarded against concurrent appends.
			var mtx sync.Mutex
			var taskArr []int

			wg := sync.WaitGroup{}
			wg.Add(tt.numJobs)
			for i := 0; i < tt.numJobs; i++ {
				num := i
				tf := makeTasks(4, func() {
					mtx.Lock()
					defer mtx.Unlock()
					taskArr = append(taskArr, num)
				})
				go func(tff TaskFunc) {
					defer wg.Done()
					s.runRecursiveJob(context.Background(), tff, jobType)
				}(tf)
			}
			wg.Wait()
			tt.expect(t, taskArr)
		})
	}
}

func BenchmarkScheduler(b *testing.B) {
s := NewWithLocation(0, monitoring.NewRegistry(), tarawaTime())
s := NewWithLocation(0, monitoring.NewRegistry(), tarawaTime(), nil)

sched := testSchedule{0}

Expand All @@ -250,7 +331,7 @@ func BenchmarkScheduler(b *testing.B) {
_, err := s.Add(sched, "testPostStop", func(_ context.Context) []TaskFunc {
executed <- struct{}{}
return nil
})
}, "http")
assert.NoError(b, err)
}

Expand All @@ -260,9 +341,7 @@ func BenchmarkScheduler(b *testing.B) {

count := 0
for count < b.N {
select {
case <-executed:
count++
}
<-executed
count++
}
}

0 comments on commit 70402c6

Please sign in to comment.