Skip to content

Commit

Permalink
fix: populate task resources as deployment progresses (#3232)
Browse files Browse the repository at this point in the history
Previously, we populated all dependencies at the start of the
deployment, meaning any intermediate outputs were not visible to
followup tasks
  • Loading branch information
jvmakine authored Oct 29, 2024
1 parent 6125749 commit 6aa6e36
Show file tree
Hide file tree
Showing 5 changed files with 177 additions and 49 deletions.
35 changes: 21 additions & 14 deletions backend/provisioner/deployment.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,12 @@ const (

// Task is a unit of work for a deployment
type Task struct {
handler provisionerconnect.ProvisionerPluginServiceClient
module string
state TaskState
desired *ResourceGraph
existing *ResourceGraph
// populated only when the task is done
output []*provisioner.Resource
handler provisionerconnect.ProvisionerPluginServiceClient
module string
state TaskState
desired *ResourceGraph

deployment *Deployment

// set if the task is currently running
runningToken string
Expand All @@ -42,27 +41,33 @@ func (t *Task) Start(ctx context.Context) error {
}
t.state = TaskStateRunning

ids := map[string]bool{}
for _, res := range t.desired.Roots() {
ids[res.ResourceId] = true
}

resp, err := t.handler.Provision(ctx, connect.NewRequest(&provisioner.ProvisionRequest{
Module: t.module,
// TODO: We need a proper cluster specific ID here
FtlClusterId: "ftl",
ExistingResources: t.existing.Roots(),
DesiredResources: t.constructResourceContext(t.desired),
ExistingResources: t.deployment.Graph.ByIDs(ids),
DesiredResources: t.constructResourceContext(t.desired.Roots(), t.deployment.Graph),
}))
if err != nil {
t.state = TaskStateFailed
return fmt.Errorf("error provisioning resources: %w", err)
}
t.runningToken = resp.Msg.ProvisioningToken

return nil
}

func (t *Task) constructResourceContext(r *ResourceGraph) []*provisioner.ResourceContext {
result := make([]*provisioner.ResourceContext, len(r.Roots()))
for i, res := range r.Roots() {
func (t *Task) constructResourceContext(resources []*provisioner.Resource, state *ResourceGraph) []*provisioner.ResourceContext {
result := make([]*provisioner.ResourceContext, len(resources))
for i, res := range resources {
result[i] = &provisioner.ResourceContext{
Resource: res,
Dependencies: r.Dependencies(res),
Dependencies: state.Dependencies(res.ResourceId),
}
}
return result
Expand All @@ -89,7 +94,7 @@ func (t *Task) Progress(ctx context.Context) error {
}
if succ, ok := resp.Msg.Status.(*provisioner.StatusResponse_Success); ok {
t.state = TaskStateDone
t.output = succ.Success.UpdatedResources
t.deployment.Graph.Update(succ.Success.UpdatedResources)
return nil
}
time.Sleep(retry.Duration())
Expand All @@ -100,6 +105,8 @@ func (t *Task) Progress(ctx context.Context) error {
type Deployment struct {
Module string
Tasks []*Task
// Graph is the current state of the resources affected by the deployment
Graph *ResourceGraph
}

// next running or pending task. Nil if all tasks are done.
Expand Down
99 changes: 90 additions & 9 deletions backend/provisioner/deployment_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package provisioner_test

import (
"context"
"fmt"
"testing"

"connectrpc.com/connect"
Expand All @@ -11,30 +12,38 @@ import (
"github.com/TBD54566975/ftl/backend/provisioner"
"github.com/TBD54566975/ftl/internal/log"
"github.com/alecthomas/assert/v2"
"github.com/google/uuid"
)

// MockProvisioner is a mock implementation of the Provisioner interface
type MockProvisioner struct {
Token string
StatusFn func(ctx context.Context, req *proto.StatusRequest) (*proto.StatusResponse, error)
ProvisionFn func(ctx context.Context, req *proto.ProvisionRequest) (*proto.ProvisionResponse, error)

stateCalls int
}

var _ provisionerconnect.ProvisionerPluginServiceClient = (*MockProvisioner)(nil)

// Ping implements provisionerconnect.ProvisionerPluginServiceClient.
func (m *MockProvisioner) Ping(context.Context, *connect.Request[ftlv1.PingRequest]) (*connect.Response[ftlv1.PingResponse], error) {
return &connect.Response[ftlv1.PingResponse]{}, nil
}

// Plan implements provisionerconnect.ProvisionerPluginServiceClient.
func (m *MockProvisioner) Plan(context.Context, *connect.Request[proto.PlanRequest]) (*connect.Response[proto.PlanResponse], error) {
panic("unimplemented")
}

// Provision implements provisionerconnect.ProvisionerPluginServiceClient.
func (m *MockProvisioner) Provision(context.Context, *connect.Request[proto.ProvisionRequest]) (*connect.Response[proto.ProvisionResponse], error) {
func (m *MockProvisioner) Provision(ctx context.Context, req *connect.Request[proto.ProvisionRequest]) (*connect.Response[proto.ProvisionResponse], error) {
if m.ProvisionFn != nil {
resp, err := m.ProvisionFn(ctx, req.Msg)
if err != nil {
return nil, err
}
return connect.NewResponse(resp), nil
}

return connect.NewResponse(&proto.ProvisionResponse{
ProvisioningToken: m.Token,
ProvisioningToken: uuid.New().String(),
}), nil
}

Expand All @@ -46,6 +55,15 @@ func (m *MockProvisioner) Status(ctx context.Context, req *connect.Request[proto
Status: &proto.StatusResponse_Running{},
}), nil
}

if m.StatusFn != nil {
rep, err := m.StatusFn(ctx, req.Msg)
if err != nil {
return nil, err
}
return connect.NewResponse(rep), nil
}

return connect.NewResponse(&proto.StatusResponse{
Status: &proto.StatusResponse_Success{
Success: &proto.StatusResponse_ProvisioningSuccess{
Expand All @@ -66,9 +84,11 @@ func TestDeployment_Progress(t *testing.T) {
})

t.Run("progresses each provisioner in order", func(t *testing.T) {
mock := &MockProvisioner{}

registry := provisioner.ProvisionerRegistry{}
registry.Register("mock", &MockProvisioner{Token: "foo"}, provisioner.ResourceTypePostgres)
registry.Register("mock", &MockProvisioner{Token: "bar"}, provisioner.ResourceTypeMysql)
registry.Register("mock", mock, provisioner.ResourceTypePostgres)
registry.Register("mock", mock, provisioner.ResourceTypeMysql)

graph := &provisioner.ResourceGraph{}
graph.AddNode(&proto.Resource{ResourceId: "a", Resource: &proto.Resource_Mysql{}})
Expand All @@ -81,7 +101,7 @@ func TestDeployment_Progress(t *testing.T) {
_, err := dpl.Progress(ctx)
assert.NoError(t, err)
assert.Equal(t, 1, len(dpl.State().Pending))
assert.NotZero(t, dpl.State().Done)
assert.NotEqual(t, 0, len(dpl.State().Done))

_, err = dpl.Progress(ctx)
assert.NoError(t, err)
Expand All @@ -92,4 +112,65 @@ func TestDeployment_Progress(t *testing.T) {
assert.Equal(t, 2, len(dpl.State().Done))
assert.False(t, running)
})

t.Run("uses output of previous task in a follow up task", func(t *testing.T) {
dbMock := &MockProvisioner{
StatusFn: func(ctx context.Context, req *proto.StatusRequest) (*proto.StatusResponse, error) {
if psql, ok := req.DesiredResources[0].Resource.(*proto.Resource_Postgres); ok {
if psql.Postgres == nil {
psql.Postgres = &proto.PostgresResource{}
}
if psql.Postgres.Output == nil {
psql.Postgres.Output = &proto.PostgresResource_PostgresResourceOutput{}
}
psql.Postgres.Output.ReadDsn = "postgres://localhost:5432/foo"
} else {
return nil, fmt.Errorf("expected postgres resource, got %T", req.DesiredResources[0].Resource)
}

return &proto.StatusResponse{
Status: &proto.StatusResponse_Success{
Success: &proto.StatusResponse_ProvisioningSuccess{
UpdatedResources: req.DesiredResources,
},
},
}, nil
},
}

moduleMock := &MockProvisioner{
ProvisionFn: func(ctx context.Context, req *proto.ProvisionRequest) (*proto.ProvisionResponse, error) {
for _, res := range req.DesiredResources {
for _, dep := range res.Dependencies {
if psql, ok := dep.Resource.(*proto.Resource_Postgres); ok && psql.Postgres != nil {
if psql.Postgres.Output == nil || psql.Postgres.Output.ReadDsn == "" {
return nil, fmt.Errorf("read dsn is empty")
}
}
}
}
return &proto.ProvisionResponse{
ProvisioningToken: uuid.New().String(),
}, nil
},
}

registry := provisioner.ProvisionerRegistry{}
registry.Register("mockdb", dbMock, provisioner.ResourceTypePostgres)
registry.Register("mockmod", moduleMock, provisioner.ResourceTypeModule)

// Check that the deployment finishes without errors
graph := &provisioner.ResourceGraph{}
graph.AddNode(&proto.Resource{ResourceId: "db", Resource: &proto.Resource_Postgres{}})
graph.AddNode(&proto.Resource{ResourceId: "mod", Resource: &proto.Resource_Module{}})

dpl := registry.CreateDeployment(ctx, "test-module", graph, &provisioner.ResourceGraph{})

running := true
for running {
r, err := dpl.Progress(ctx)
assert.NoError(t, err)
running = r
}
})
}
23 changes: 13 additions & 10 deletions backend/provisioner/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -135,28 +135,31 @@ func (reg *ProvisionerRegistry) Register(id string, handler provisionerconnect.P
func (reg *ProvisionerRegistry) CreateDeployment(ctx context.Context, module string, desiredResources, existingResources *ResourceGraph) *Deployment {
logger := log.FromContext(ctx)

var result []*Task

existingByHandler := reg.groupByProvisioner(existingResources.Resources())
desiredByHandler := reg.groupByProvisioner(desiredResources.Resources())

deployment := &Deployment{
Module: module,
Graph: desiredResources,
}

for _, binding := range reg.listProvisioners() {
desired := desiredByHandler[binding.Provisioner]
existing := existingByHandler[binding.Provisioner]

if !resourcesEqual(desired, existing) {
logger.Debugf("Adding task for module %s: %s", module, binding)
result = append(result, &Task{
module: module,
handler: binding.Provisioner,
desired: desiredResources.WithDirectDependencies(desired),
existing: existingResources.WithDirectDependencies(existing),
deployment.Tasks = append(deployment.Tasks, &Task{
module: module,
handler: binding.Provisioner,
deployment: deployment,
desired: desiredResources.WithDirectDependencies(desired),
})
} else {
logger.Debugf("Skipping task for module %s with provisioner %s", module, binding.ID)
}
}
return &Deployment{Tasks: result, Module: module}
return deployment
}

func resourcesEqual(desired, existing []*provisioner.Resource) bool {
Expand Down Expand Up @@ -231,8 +234,8 @@ func ExtractResources(msg *ftlv1.CreateDeploymentRequest) (*ResourceGraph, error
edges := make([]*ResourceEdge, len(deps))
for i, dep := range deps {
edges[i] = &ResourceEdge{
from: root,
to: dep,
from: root.ResourceId,
to: dep.ResourceId,
}
}

Expand Down
Loading

0 comments on commit 6aa6e36

Please sign in to comment.