Skip to content

Commit

Permalink
feature: Add persistance to command scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
maxekman committed Nov 8, 2021
1 parent 3602de4 commit 609ad4d
Show file tree
Hide file tree
Showing 2 changed files with 500 additions and 91 deletions.
336 changes: 301 additions & 35 deletions middleware/commandhandler/scheduler/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,47 +16,18 @@ package scheduler

import (
"context"
"encoding/json"
"errors"
"fmt"
"sync"
"time"

eh "github.com/looplab/eventhorizon"
"github.com/looplab/eventhorizon/uuid"
)

// NewMiddleware returns a new async handling middleware that returns any errors
// on a error channel.
func NewMiddleware() (eh.CommandHandlerMiddleware, chan *Error) {
errCh := make(chan *Error, 20)

return eh.CommandHandlerMiddleware(func(h eh.CommandHandler) eh.CommandHandler {
return eh.CommandHandlerFunc(func(ctx context.Context, cmd eh.Command) error {
// Delayed command execution if there is time set.
if c, ok := cmd.(Command); ok && !c.ExecuteAt().IsZero() {
go func() {
t := time.NewTimer(time.Until(c.ExecuteAt()))
defer t.Stop()

var err error
select {
case <-ctx.Done():
err = ctx.Err()
case <-t.C:
err = h.HandleCommand(ctx, cmd)
}

if err != nil {
// Always try to deliver errors.
errCh <- &Error{err, ctx, cmd}
}
}()

return nil
}

// Immediate command execution.
return h.HandleCommand(ctx, cmd)
})
}), errCh
}
// ErrCanceled is when a scheduled command has been canceled.
var ErrCanceled = errors.New("canceled")

// Command is a scheduled command with an execution time.
type Command interface {
Expand All @@ -82,6 +53,301 @@ func (c *command) ExecuteAt() time.Time {
return c.t
}

// NewMiddleware returns a new async handling middleware that returns any errors
// on a error channel.
func NewMiddleware(repo eh.ReadWriteRepo) (eh.CommandHandlerMiddleware, *Scheduler) {
s := &Scheduler{
cmdCh: make(chan *scheduledCommand, 100),
cancelScheduling: map[uuid.UUID]chan struct{}{},
repo: repo,
errCh: make(chan error, 100),
}

return eh.CommandHandlerMiddleware(func(h eh.CommandHandler) eh.CommandHandler {
s.setHandler(h)

return eh.CommandHandlerFunc(func(ctx context.Context, cmd eh.Command) error {
// Delayed command execution if there is time set.
if c, ok := cmd.(Command); ok && !c.ExecuteAt().IsZero() {
// Use the wrapped command when created by the helper func.
innerCmd, ok := c.(*command)
if ok {
cmd = innerCmd.Command
}

// Ignore the persisted command ID in this case.
_, err := s.ScheduleCommand(ctx, cmd, c.ExecuteAt().UTC())

return err
}

// Immediate command execution.
return h.HandleCommand(ctx, cmd)
})
}), s
}

// PersistedCommand is a persisted command.
type PersistedCommand struct {
ID uuid.UUID `json:"_" bson:"_id"`
IDStr string `json:"id" bson:"_"`
CommandType eh.CommandType `json:"command_type" bson:"command_type"`
Command eh.Command `json:"-" bson:"-"`
RawCommand json.RawMessage `json:"command,omitempty" bson:"command,omitempty"`
ExecuteAt time.Time `json:"timestamp" bson:"timestamp"`
Context map[string]interface{} `json:"context" bson:"context"`
}

// EntityID implements the EntityID method of the eventhorizon.Entity interface.
func (c *PersistedCommand) EntityID() uuid.UUID {
return c.ID
}

// Scheduler is a scheduled of commands.
type Scheduler struct {
h eh.CommandHandler
hMu sync.Mutex
cmdCh chan *scheduledCommand
cancelScheduling map[uuid.UUID]chan struct{}
cancelSchedulingMu sync.Mutex
repo eh.ReadWriteRepo
errCh chan error
cctx context.Context
cancel context.CancelFunc
done chan struct{}
}

func (s *Scheduler) setHandler(h eh.CommandHandler) {
s.hMu.Lock()
defer s.hMu.Unlock()

s.h = h
}

// Start starts the scheduler by first loading all persisted commands.
func (s *Scheduler) Start() error {
if err := s.loadCommands(); err != nil {
return fmt.Errorf("could not load commands: %w", err)
}

s.cctx, s.cancel = context.WithCancel(context.Background())
s.done = make(chan struct{})

go s.run()

return nil
}

// Stop stops all scheduled commands.
func (s *Scheduler) Stop() error {
s.cancel()

<-s.done

return nil
}

// Errors returns an error channel that will receive errors from handling of
// scheduled commands.
func (s *Scheduler) Errors() <-chan error {
return s.errCh
}

type scheduledCommand struct {
id uuid.UUID
ctx context.Context
cmd eh.Command
executeAt time.Time
}

// ScheduleCommand schedules a command to be executed at time t. It is persisted
// to the repo.
func (s *Scheduler) ScheduleCommand(ctx context.Context, cmd eh.Command, t time.Time) (uuid.UUID, error) {
b, err := json.Marshal(cmd)
if err != nil {
return uuid.Nil, &Error{
Err: fmt.Errorf("could not marshal command: %w", err),
Ctx: ctx,
Command: cmd,
}
}

// Use the command ID as persisted ID if available.
var id uuid.UUID
if cmd, ok := cmd.(eh.CommandIDer); ok {
id = cmd.CommandID()
} else {
id = uuid.New()
}

pc := &PersistedCommand{
ID: id,
CommandType: cmd.CommandType(),
RawCommand: b,
ExecuteAt: t.UTC(),
}
if err := s.repo.Save(context.Background(), pc); err != nil {
return uuid.Nil, &Error{
Err: fmt.Errorf("could not persist command: %w", err),
Ctx: ctx,
Command: cmd,
}
}

select {
case s.cmdCh <- &scheduledCommand{id, ctx, cmd, t.UTC()}:
default:
return uuid.Nil, &Error{
Err: fmt.Errorf("command queue full"),
Ctx: ctx,
Command: cmd,
}
}

return pc.ID, nil
}

// Commands returns all scheduled commands.
func (s *Scheduler) Commands(ctx context.Context) ([]*PersistedCommand, error) {
entities, err := s.repo.FindAll(ctx)
if err != nil {
return nil, fmt.Errorf("could not load scheduled commands: %w", err)
}

commands := make([]*PersistedCommand, len(entities))

for i, entity := range entities {
c, ok := entity.(*PersistedCommand)
if !ok {
return nil, fmt.Errorf("command is not schedulable: %T", entity)
}

c.Command, err = eh.CreateCommand(c.CommandType)
if err != nil {
return nil, fmt.Errorf("could not create command: %w", err)
}

// TODO: Allow any type of marshaler.
if err := json.Unmarshal(c.RawCommand, &c.Command); err != nil {
return nil, fmt.Errorf("could not unmarshal command: %w", err)
}

if c.IDStr != "" {
id, err := uuid.Parse(c.IDStr)
if err != nil {
return nil, fmt.Errorf("could not parse command ID: %w", err)
}

c.ID = id
}

commands[i] = c
}

return commands, nil
}

// CancelCommand cancels a scheduled command.
func (s *Scheduler) CancelCommand(ctx context.Context, id uuid.UUID) error {
s.cancelSchedulingMu.Lock()
defer s.cancelSchedulingMu.Unlock()

cancel, ok := s.cancelScheduling[id]
if !ok {
return fmt.Errorf("command %s not scheduled", id)
}

close(cancel)

return nil
}

func (s *Scheduler) loadCommands() error {
commands, err := s.Commands(context.Background())
if err != nil {
return fmt.Errorf("could not load scheduled commands: %w", err)
}

for _, pc := range commands {
// Use the persisted context.
ctx := eh.UnmarshalContext(context.Background(), pc.Context)

select {
case s.cmdCh <- &scheduledCommand{pc.ID, ctx, pc.Command, pc.ExecuteAt.UTC()}:
default:
return fmt.Errorf("could not schedule command: %w", err)
}
}

return nil
}

func (s *Scheduler) run() {
var wg sync.WaitGroup

loop:
for {
select {
case <-s.cctx.Done():
break loop
case sc := <-s.cmdCh:
wg.Add(1)

s.cancelSchedulingMu.Lock()
cancel := make(chan struct{})
s.cancelScheduling[sc.id] = cancel
s.cancelSchedulingMu.Unlock()

go func(cancel chan struct{}) {
defer wg.Done()

t := time.NewTimer(time.Until(sc.executeAt))
defer t.Stop()

select {
case <-s.cctx.Done():
// Stop without removing persisted cmd.
case <-t.C:
if err := s.h.HandleCommand(sc.ctx, sc.cmd); err != nil {
// Always try to deliver errors.
s.errCh <- &Error{
Err: err,
Ctx: sc.ctx,
Command: sc.cmd,
}
}

if err := s.repo.Remove(context.Background(), sc.id); err != nil {
s.errCh <- &Error{
Err: fmt.Errorf("could not remove persisted command: %w", err),
Ctx: sc.ctx,
Command: sc.cmd,
}
}
case <-cancel:
if err := s.repo.Remove(context.Background(), sc.id); err != nil {
s.errCh <- &Error{
Err: fmt.Errorf("could not remove persisted command: %w", err),
Ctx: sc.ctx,
Command: sc.cmd,
}
}

s.errCh <- &Error{
Err: ErrCanceled,
Ctx: sc.ctx,
Command: sc.cmd,
}
}
}(cancel)
}
}

wg.Wait()

close(s.done)
}

// Error is an async error containing the error and the command.
type Error struct {
// Err is the error that happened when handling the command.
Expand Down
Loading

0 comments on commit 609ad4d

Please sign in to comment.