Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ Both client libraries are pre-1.0, and they have separate versioning.

## Unreleased

No unreleased changes.
- All Go SDK functions that take a Context will respect the timeout of the context.

## modal-js/v0.5.0, modal-go/v0.5.0

Expand Down
2 changes: 1 addition & 1 deletion modal-go/function.go
Original file line number Diff line number Diff line change
Expand Up @@ -371,7 +371,7 @@ func blobUpload(ctx context.Context, client pb.ModalClientClient, data []byte) (
return "", fmt.Errorf("Function input size exceeds multipart upload threshold, unsupported by this SDK version")

case pb.BlobCreateResponse_UploadUrl_case:
req, err := http.NewRequest("PUT", resp.GetUploadUrl(), bytes.NewReader(data))
req, err := http.NewRequestWithContext(ctx, "PUT", resp.GetUploadUrl(), bytes.NewReader(data))
if err != nil {
return "", fmt.Errorf("failed to create upload request: %w", err)
}
Expand Down
8 changes: 8 additions & 0 deletions modal-go/image.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,10 @@ func (image *Image) Build(ctx context.Context, app *App) (*Image, error) {
var currentImageID string

for i, currentLayer := range image.layers {
if err := ctx.Err(); err != nil {
return nil, err
}

mergedSecrets, err := mergeEnvIntoSecrets(ctx, image.client, &currentLayer.env, &currentLayer.secrets)
if err != nil {
return nil, err
Expand Down Expand Up @@ -255,6 +259,10 @@ func (image *Image) Build(ctx context.Context, app *App) (*Image, error) {
// Not built or in the process of building - wait for build
lastEntryID := ""
for result == nil {
if err := ctx.Err(); err != nil {
return nil, err
}

stream, err := image.client.cpClient.ImageJoinStreaming(ctx, pb.ImageJoinStreamingRequest_builder{
ImageId: resp.GetImageId(),
Timeout: 55,
Expand Down
10 changes: 9 additions & 1 deletion modal-go/invocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ func pollFunctionOutput(ctx context.Context, client pb.ModalClientClient, getOut
}

for {
if err := ctx.Err(); err != nil {
return nil, err
}

output, err := getOutput(ctx, pollTimeout)
if err != nil {
return nil, err
Expand Down Expand Up @@ -246,7 +250,11 @@ func blobDownload(ctx context.Context, client pb.ModalClientClient, blobID strin
if err != nil {
return nil, err
}
s3resp, err := http.Get(resp.GetDownloadUrl())
req, err := http.NewRequestWithContext(ctx, "GET", resp.GetDownloadUrl(), nil)
if err != nil {
return nil, fmt.Errorf("failed to create download request: %w", err)
}
s3resp, err := http.DefaultClient.Do(req)
if err != nil {
return nil, fmt.Errorf("failed to download blob: %w", err)
}
Expand Down
13 changes: 13 additions & 0 deletions modal-go/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,10 @@ func (q *Queue) get(ctx context.Context, n int, params *QueueGetParams) ([]any,
}

for {
if err := ctx.Err(); err != nil {
return nil, err
}

resp, err := q.client.cpClient.QueueGet(ctx, pb.QueueGetRequest_builder{
QueueId: q.QueueID,
PartitionKey: partitionKey,
Expand Down Expand Up @@ -284,6 +288,10 @@ func (q *Queue) put(ctx context.Context, values []any, params *QueuePutParams) e
}

for {
if err := ctx.Err(); err != nil {
return err
}

_, err := q.client.cpClient.QueuePut(ctx, pb.QueuePutRequest_builder{
QueueId: q.QueueID,
Values: valuesEncoded,
Expand Down Expand Up @@ -400,6 +408,11 @@ func (q *Queue) Iterate(ctx context.Context, params *QueueIterateParams) iter.Se

fetchDeadline := time.Now().Add(itemPoll)
for {
if err := ctx.Err(); err != nil {
yield(nil, err)
return
}

pollDuration := max(0, min(maxPoll, time.Until(fetchDeadline)))
resp, err := q.client.cpClient.QueueNextItems(ctx, pb.QueueNextItemsRequest_builder{
QueueId: q.QueueID,
Expand Down
13 changes: 13 additions & 0 deletions modal-go/sandbox.go
Original file line number Diff line number Diff line change
Expand Up @@ -537,6 +537,10 @@ func (sb *Sandbox) Terminate(ctx context.Context) error {
// Wait blocks until the Sandbox exits.
func (sb *Sandbox) Wait(ctx context.Context) (int, error) {
for {
if err := ctx.Err(); err != nil {
return 0, err
}

resp, err := sb.client.cpClient.SandboxWait(ctx, pb.SandboxWaitRequest_builder{
SandboxId: sb.SandboxID,
Timeout: 10,
Expand Down Expand Up @@ -677,6 +681,11 @@ func (s *sandboxServiceImpl) List(ctx context.Context, params *SandboxListParams
return func(yield func(*Sandbox, error) bool) {
var before float64
for {
if err := ctx.Err(); err != nil {
yield(nil, err)
return
}

resp, err := s.client.cpClient.SandboxList(ctx, pb.SandboxListRequest_builder{
AppId: params.AppID,
BeforeTimestamp: before,
Expand Down Expand Up @@ -764,6 +773,10 @@ func newContainerProcess(cpClient pb.ModalClientClient, execID string, params Sa
// Wait blocks until the container process exits and returns its exit code.
func (cp *ContainerProcess) Wait(ctx context.Context) (int, error) {
for {
if err := ctx.Err(); err != nil {
return 0, err
}

resp, err := cp.cpClient.ContainerExecWait(ctx, pb.ContainerExecWaitRequest_builder{
ExecId: cp.execID,
Timeout: 55,
Expand Down
114 changes: 114 additions & 0 deletions modal-go/test/grpc_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
package test

import (
"context"
"net"
"testing"
"time"

"github.com/modal-labs/libmodal/modal-go"
pb "github.com/modal-labs/libmodal/modal-go/proto/modal_proto"
"github.com/onsi/gomega"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/status"
"google.golang.org/grpc/test/bufconn"
)

type slowModalServer struct {
pb.UnimplementedModalClientServer
sleepDuration time.Duration
}

// AppGetOrCreate is just chosen arbitrarily as a GRPC method to use for testing.
func (s *slowModalServer) AppGetOrCreate(ctx context.Context, req *pb.AppGetOrCreateRequest) (*pb.AppGetOrCreateResponse, error) {
select {
case <-time.After(s.sleepDuration):
return pb.AppGetOrCreateResponse_builder{AppId: req.GetAppName()}.Build(), nil
case <-ctx.Done():
return nil, ctx.Err()
}
}

func (s *slowModalServer) AuthTokenGet(ctx context.Context, req *pb.AuthTokenGetRequest) (*pb.AuthTokenGetResponse, error) {
return pb.AuthTokenGetResponse_builder{Token: "test-token"}.Build(), nil
}

func TestAppFromName_RespectsContextDeadline(t *testing.T) {
t.Parallel()

testCases := []struct {
name string
serverSleep time.Duration
contextTimeout time.Duration
expectTimeout bool
}{
{
name: "deadline exceeded",
serverSleep: 100 * time.Millisecond,
contextTimeout: 10 * time.Millisecond,
expectTimeout: true,
},
{
name: "completes before deadline",
serverSleep: 10 * time.Millisecond,
contextTimeout: 100 * time.Millisecond,
expectTimeout: false,
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
g := gomega.NewWithT(t)

lis := bufconn.Listen(1024 * 1024)

grpcServer := grpc.NewServer()
pb.RegisterModalClientServer(grpcServer, &slowModalServer{
sleepDuration: tc.serverSleep,
})

go func() {
if err := grpcServer.Serve(lis); err != nil {
t.Logf("Server error: %v", err)
}
}()
defer grpcServer.Stop()

bufDialer := func(context.Context, string) (net.Conn, error) {
return lis.Dial()
}

conn, err := grpc.NewClient("passthrough:///bufnet",
grpc.WithContextDialer(bufDialer),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
g.Expect(err).ShouldNot(gomega.HaveOccurred())
defer conn.Close()

client, err := modal.NewClientWithOptions(&modal.ClientParams{
TokenID: "test-token-id",
TokenSecret: "test-token-secret",
Environment: "test",
ControlPlaneClient: pb.NewModalClientClient(conn),
})
g.Expect(err).ShouldNot(gomega.HaveOccurred())

ctxWithTimeout, cancel := context.WithTimeout(context.Background(), tc.contextTimeout)
defer cancel()

app, err := client.Apps.FromName(ctxWithTimeout, "test-app", nil)

if tc.expectTimeout {
g.Expect(err).Should(gomega.HaveOccurred())
st, ok := status.FromError(err)
g.Expect(ok).To(gomega.BeTrue())
g.Expect(st.Code()).To(gomega.Equal(codes.DeadlineExceeded))
} else {
g.Expect(err).ShouldNot(gomega.HaveOccurred())
g.Expect(app.AppID).To(gomega.Equal("test-app"))
}
})
}
}