4 changes: 4 additions & 0 deletions cmd/dequeuer/main.go
@@ -50,6 +50,7 @@ func main() {
statsdAddress string
apiKind string
adminPort int
workers int
)
flag.StringVar(&clusterConfigPath, "cluster-config", "", "cluster config path")
flag.StringVar(&clusterUID, "cluster-uid", "", "cluster unique identifier")
@@ -61,6 +62,7 @@ func main() {
flag.StringVar(&statsdAddress, "statsd-address", "", "address to push statsd metrics")
flag.IntVar(&userContainerPort, "user-port", 8080, "target port to which the dequeued messages will be sent")
flag.IntVar(&adminPort, "admin-port", 0, "port where the admin server (for the probes) will be exposed")
flag.IntVar(&workers, "workers", 1, "number of workers pulling from the queue")

flag.Parse()

@@ -166,6 +168,7 @@ func main() {
Region: clusterConfig.Region,
QueueURL: queueURL,
StopIfNoMessages: true,
Workers: workers,
}

case userconfig.AsyncAPIKind.String():
@@ -186,6 +189,7 @@ func main() {
Region: clusterConfig.Region,
QueueURL: queueURL,
StopIfNoMessages: false,
Workers: workers,
}

// report prometheus metrics for async api kinds
1 change: 1 addition & 0 deletions docs/workloads/async/configuration.md
@@ -5,6 +5,7 @@
kind: AsyncAPI # must be "AsyncAPI" for async APIs (required)
pod: # pod configuration (required)
port: <int> # port to which requests will be sent (default: 8080; exported as $CORTEX_PORT)
max_concurrency: <int> # maximum number of requests that will be concurrently sent into the container (default: 1, max allowed: 100)
containers: # configurations for the containers to run (at least one container must be provided)
- name: <string> # name of the container (required)
image: <string> # docker image to use for the container (required)
6 changes: 2 additions & 4 deletions pkg/dequeuer/batch_handler.go
@@ -158,7 +158,7 @@ func (h *BatchMessageHandler) handleBatch(message *sqs.Message) error {
return nil
}

endTime := time.Now().Sub(startTime)
endTime := time.Since(startTime)

err = h.recordSuccess()
if err != nil {
@@ -175,7 +175,7 @@ func (h *BatchMessageHandler) handleBatch(message *sqs.Message) error {
func (h *BatchMessageHandler) onJobComplete(message *sqs.Message) error {
shouldRunOnJobComplete := false
h.log.Info("received job_complete message")
for true {
for {
queueAttributes, err := GetQueueAttributes(h.aws, h.config.QueueURL)
if err != nil {
return err
@@ -223,8 +223,6 @@ func (h *BatchMessageHandler) onJobComplete(message *sqs.Message) error {

time.Sleep(h.jobCompleteMessageDelay)
}

return nil
}

func isOnJobCompleteMessage(message *sqs.Message) bool {
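Both replacements in this file are standard Go cleanups: `time.Since(t)` is the idiomatic shorthand for `time.Now().Sub(t)`, and a bare `for` replaces `for true`, which in turn lets the now-unreachable `return nil` at the end of `onJobComplete` be dropped. A minimal sketch of the two idioms (illustrative only, not code from this repo):

```go
package main

import (
	"fmt"
	"time"
)

func main() {
	startTime := time.Now()

	// time.Since(startTime) reads better than time.Now().Sub(startTime)
	// and is what tools like staticcheck (S1012) suggest.
	elapsed := time.Since(startTime)
	fmt.Println("elapsed:", elapsed)

	// The idiomatic infinite loop is a bare `for`; `for true` is flagged
	// by linters, and any code placed after such a loop is unreachable.
	for {
		break // exit via break/return inside the loop
	}
}
```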
35 changes: 31 additions & 4 deletions pkg/dequeuer/dequeuer.go
@@ -23,6 +23,7 @@ import (
"github.com/aws/aws-sdk-go/service/sqs"
awslib "github.com/cortexlabs/cortex/pkg/lib/aws"
"github.com/cortexlabs/cortex/pkg/lib/errors"
"github.com/cortexlabs/cortex/pkg/lib/math"
"github.com/cortexlabs/cortex/pkg/lib/telemetry"
"go.uber.org/zap"
)
@@ -40,6 +41,7 @@ type SQSDequeuerConfig struct {
Region string
QueueURL string
StopIfNoMessages bool
Workers int
}

type SQSDequeuer struct {
@@ -96,12 +98,37 @@ func (d *SQSDequeuer) ReceiveMessage() (*sqs.Message, error) {
}

func (d *SQSDequeuer) Start(messageHandler MessageHandler, readinessProbeFunc func() bool) error {
numWorkers := math.MaxInt(d.config.Workers, 1)

d.log.Infof("Starting %d workers", numWorkers)
errCh := make(chan error)
doneChs := make([]chan struct{}, numWorkers)
for i := 0; i < numWorkers; i++ {
doneChs[i] = make(chan struct{})
go func(i int) {
errCh <- d.worker(messageHandler, readinessProbeFunc, doneChs[i])
}(i)
}

select {
case err := <-errCh:
return err
case <-d.done:
for _, doneCh := range doneChs {
doneCh <- struct{}{}
}
}

return <-errCh
}

func (d SQSDequeuer) worker(messageHandler MessageHandler, readinessProbeFunc func() bool, workerDone chan struct{}) error {
noMessagesInPreviousIteration := false

loop:
for {
select {
case <-d.done:
case <-workerDone:
break loop
default:
if !readinessProbeFunc() {
@@ -134,8 +161,8 @@

noMessagesInPreviousIteration = false
receiptHandle := *message.ReceiptHandle
done := d.StartMessageRenewer(receiptHandle)
err = d.handleMessage(message, messageHandler, done)
renewerDone := d.StartMessageRenewer(receiptHandle)
err = d.handleMessage(message, messageHandler, renewerDone)
if err != nil {
d.log.Error(err)
telemetry.Error(err)
@@ -196,7 +223,7 @@ func (d *SQSDequeuer) StartMessageRenewer(receiptHandle string) chan struct{} {
startTime := time.Now()
go func() {
defer ticker.Stop()
for true {
for {
select {
case <-done:
return
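The reworked `Start` fans out the existing polling loop across N goroutines: each worker gets its own done channel, the first error on the shared error channel wins, and `Shutdown` (via `d.done`) signals every worker. A simplified, self-contained sketch of that fan-out pattern — the `worker` body here is a stand-in, not the real dequeuer logic:

```go
package main

import (
	"fmt"
	"time"
)

// worker loops until its done channel is signalled, then returns.
func worker(id int, done <-chan struct{}) error {
	for {
		select {
		case <-done:
			fmt.Printf("worker %d stopping\n", id)
			return nil
		default:
			time.Sleep(10 * time.Millisecond) // stand-in for one receive/handle iteration
		}
	}
}

func main() {
	numWorkers := 3
	errCh := make(chan error)
	doneChs := make([]chan struct{}, numWorkers)

	for i := 0; i < numWorkers; i++ {
		doneChs[i] = make(chan struct{})
		go func(i int) {
			errCh <- worker(i, doneChs[i])
		}(i)
	}

	// Shut down: signal each worker, then drain every result so no goroutine
	// is left blocked sending on the unbuffered error channel.
	for _, doneCh := range doneChs {
		doneCh <- struct{}{}
	}
	for i := 0; i < numWorkers; i++ {
		if err := <-errCh; err != nil {
			fmt.Println("worker error:", err)
		}
	}
}
```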
75 changes: 75 additions & 0 deletions pkg/dequeuer/dequeuer_test.go
@@ -31,6 +31,7 @@ import (
"github.com/aws/aws-sdk-go/service/sqs"
awslib "github.com/cortexlabs/cortex/pkg/lib/aws"
"github.com/cortexlabs/cortex/pkg/lib/random"
"github.com/cortexlabs/cortex/pkg/lib/sets/strset"
"github.com/ory/dockertest/v3"
dc "github.com/ory/dockertest/v3/docker"
"github.com/stretchr/testify/require"
@@ -179,6 +180,7 @@ func TestSQSDequeuer_ReceiveMessage(t *testing.T) {
Region: _localStackDefaultRegion,
QueueURL: queueURL,
StopIfNoMessages: true,
Workers: 1,
}, awsClient, logger,
)
require.NoError(t, err)
@@ -205,6 +207,7 @@ func TestSQSDequeuer_StartMessageRenewer(t *testing.T) {
Region: _localStackDefaultRegion,
QueueURL: queueURL,
StopIfNoMessages: true,
Workers: 1,
}, awsClient, logger,
)
require.NoError(t, err)
@@ -253,6 +256,7 @@ func TestSQSDequeuerTerminationOnEmptyQueue(t *testing.T) {
Region: _localStackDefaultRegion,
QueueURL: queueURL,
StopIfNoMessages: true,
Workers: 1,
}, awsClient, logger,
)
require.NoError(t, err)
@@ -303,6 +307,7 @@ func TestSQSDequeuer_Shutdown(t *testing.T) {
Region: _localStackDefaultRegion,
QueueURL: queueURL,
StopIfNoMessages: true,
Workers: 1,
}, awsClient, logger,
)
require.NoError(t, err)
@@ -345,6 +350,7 @@ func TestSQSDequeuer_Start_HandlerError(t *testing.T) {
Region: _localStackDefaultRegion,
QueueURL: queueURL,
StopIfNoMessages: true,
Workers: 1,
}, awsClient, logger,
)
require.NoError(t, err)
@@ -383,3 +389,72 @@
return msg != nil
}, 5*time.Second, time.Second)
}

func TestSQSDequeuer_MultipleWorkers(t *testing.T) {
t.Parallel()

awsClient := testAWSClient(t)
queueURL := createQueue(t, awsClient)

numMessages := 3
expectedMsgs := make([]string, numMessages)
for i := 0; i < numMessages; i++ {
message := fmt.Sprintf("%d", i)
expectedMsgs[i] = message
_, err := awsClient.SQS().SendMessage(&sqs.SendMessageInput{
MessageBody: aws.String(message),
MessageDeduplicationId: aws.String(message),
MessageGroupId: aws.String(message),
QueueUrl: aws.String(queueURL),
})
require.NoError(t, err)
}

logger := newLogger(t)
defer func() { _ = logger.Sync() }()

dq, err := NewSQSDequeuer(
SQSDequeuerConfig{
Region: _localStackDefaultRegion,
QueueURL: queueURL,
StopIfNoMessages: true,
Workers: numMessages,
}, awsClient, logger,
)
require.NoError(t, err)

dq.waitTimeSeconds = aws.Int64(0)
dq.notFoundSleepTime = 0

msgCh := make(chan string, numMessages)
handler := NewMessageHandlerFunc(
func(message *sqs.Message) error {
msgCh <- *message.Body
return nil
},
)

errCh := make(chan error)
go func() {
errCh <- dq.Start(handler, func() bool { return true })
}()

receivedMessages := make([]string, numMessages)
for i := 0; i < numMessages; i++ {
receivedMessages[i] = <-msgCh
}
dq.Shutdown()

// fail the test if the dequeuer does not stop within 10 seconds
time.AfterFunc(10*time.Second, func() {
close(msgCh)
errCh <- errors.New("test timed out")
})

require.Len(t, receivedMessages, numMessages)

set := strset.FromSlice(receivedMessages)
require.True(t, set.Has(expectedMsgs...))

require.NoError(t, <-errCh)
}
6 changes: 4 additions & 2 deletions pkg/dequeuer/probes.go
@@ -35,10 +35,12 @@ func ProbesFromFile(probesPath string, logger *zap.SugaredLogger) ([]*probe.Probe, error) {
return nil, err
}

var probesSlice []*probe.Probe
probesSlice := make([]*probe.Probe, len(probesMap))
var i int
for _, p := range probesMap {
auxProbe := p
probesSlice = append(probesSlice, probe.NewProbe(&auxProbe, logger))
probesSlice[i] = probe.NewProbe(&auxProbe, logger)
i++
}
return probesSlice, nil
}
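Two small things happen in this hunk: the slice is preallocated to its known final length instead of being grown with `append`, and the `auxProbe := p` copy is preserved because `probe.NewProbe` takes a pointer, and taking `&p` directly would alias the range variable, which Go reuses across iterations (prior to the Go 1.22 loop-variable change). A minimal illustration of that aliasing pitfall, using stand-in types rather than the real probe package:

```go
package main

import "fmt"

func main() {
	values := []int{1, 2, 3}

	// Buggy before Go 1.22: every pointer refers to the single, reused loop variable.
	var bad []*int
	for _, v := range values {
		bad = append(bad, &v)
	}

	// Safe in any Go version: copy into a per-iteration variable first
	// (this is what auxProbe := p does in the diff), and preallocate
	// since the final length is known up front.
	good := make([]*int, len(values))
	var i int
	for _, v := range values {
		v := v
		good[i] = &v
		i++
	}

	fmt.Println(*bad[0], *bad[1], *bad[2])    // may print 3 3 3 on Go < 1.22
	fmt.Println(*good[0], *good[1], *good[2]) // always prints 1 2 3
}
```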
13 changes: 13 additions & 0 deletions pkg/types/spec/validations.go
@@ -189,6 +189,19 @@ func podValidation(kind userconfig.Kind) *cr.StructFieldValidation {
)
}

if kind == userconfig.AsyncAPIKind {
validation.StructValidation.StructFieldValidations = append(validation.StructValidation.StructFieldValidations,
&cr.StructFieldValidation{
StructField: "MaxConcurrency",
Int64Validation: &cr.Int64Validation{
Default: consts.DefaultMaxConcurrency,
GreaterThan: pointer.Int64(0),
LessThanOrEqualTo: pointer.Int64(100),
},
},
)
}

return validation
}

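This registers an extra field validation only for AsyncAPI kinds: `max_concurrency` falls back to `consts.DefaultMaxConcurrency` when unset and must lie in (0, 100], matching the `default: 1, max allowed: 100` noted in the docs change above. A plain-Go sketch of the constraint being enforced — `validateMaxConcurrency` and the constant's value are assumptions for illustration, not the cr validation framework itself:

```go
package main

import "fmt"

// Assumed to be 1, per the docs line: "default: 1, max allowed: 100".
const defaultMaxConcurrency int64 = 1

// validateMaxConcurrency mirrors the Int64Validation in the diff:
// GreaterThan: 0, LessThanOrEqualTo: 100, with a default applied when unset.
func validateMaxConcurrency(v int64, isSet bool) (int64, error) {
	if !isSet {
		return defaultMaxConcurrency, nil
	}
	if v <= 0 || v > 100 {
		return 0, fmt.Errorf("max_concurrency must be > 0 and <= 100, got %d", v)
	}
	return v, nil
}

func main() {
	fmt.Println(validateMaxConcurrency(0, false)) // 1 <nil> (default applied)
	fmt.Println(validateMaxConcurrency(50, true)) // 50 <nil>
	fmt.Println(validateMaxConcurrency(101, true))
	// 0 max_concurrency must be > 0 and <= 100, got 101
}
```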
23 changes: 12 additions & 11 deletions pkg/workloads/k8s.go
@@ -133,6 +133,7 @@ func asyncDequeuerProxyContainer(api spec.API, queueURL string) (kcore.Container
"--statsd-address", _statsdAddress,
"--user-port", s.Int32(*api.Pod.Port),
"--admin-port", consts.AdminPortStr,
"--workers", s.Int64(api.Pod.MaxConcurrency),
},
Env: BaseEnvVars,
EnvFrom: BaseClusterEnvVars(),
@@ -366,8 +367,8 @@ func userPodContainers(api spec.API) ([]kcore.Container, []kcore.Volume) {
ClientConfigMount(),
}

var containers []kcore.Container
for _, container := range api.Pod.Containers {
containers := make([]kcore.Container, len(api.Pod.Containers))
for i, container := range api.Pod.Containers {
containerResourceList := kcore.ResourceList{}
containerResourceLimitsList := kcore.ResourceList{}
securityContext := kcore.SecurityContext{
@@ -433,7 +434,7 @@ func userPodContainers(api spec.API) ([]kcore.Container, []kcore.Volume) {
})
}

containers = append(containers, kcore.Container{
containers[i] = kcore.Container{
Name: container.Name,
Image: container.Image,
Command: container.Command,
@@ -448,7 +449,7 @@ func userPodContainers(api spec.API) ([]kcore.Container, []kcore.Volume) {
},
ImagePullPolicy: kcore.PullAlways,
SecurityContext: &securityContext,
})
}
}

return containers, volumes
@@ -498,18 +499,17 @@ func GenerateNodeAffinities(apiNodeGroups []string) *kcore.Affinity {
nodeGroups = config.ClusterConfig.NodeGroups
}

var requiredNodeGroups []string
var preferredAffinities []kcore.PreferredSchedulingTerm

for _, nodeGroup := range nodeGroups {
requiredNodeGroups := make([]string, len(nodeGroups))
preferredAffinities := make([]kcore.PreferredSchedulingTerm, len(nodeGroups))
for i, nodeGroup := range nodeGroups {
var nodeGroupPrefix string
if nodeGroup.Spot {
nodeGroupPrefix = "cx-ws-"
} else {
nodeGroupPrefix = "cx-wd-"
}

preferredAffinities = append(preferredAffinities, kcore.PreferredSchedulingTerm{
preferredAffinities[i] = kcore.PreferredSchedulingTerm{
Weight: int32(nodeGroup.Priority),
Preference: kcore.NodeSelectorTerm{
MatchExpressions: []kcore.NodeSelectorRequirement{
@@ -520,8 +520,9 @@ func GenerateNodeAffinities(apiNodeGroups []string) *kcore.Affinity {
},
},
},
})
requiredNodeGroups = append(requiredNodeGroups, nodeGroupPrefix+nodeGroup.Name)
}

requiredNodeGroups[i] = nodeGroupPrefix + nodeGroup.Name
}

var requiredNodeSelector *kcore.NodeSelector