diff --git a/grpcweb/client.go b/grpcweb/client.go index 4f3b662..4995c45 100644 --- a/grpcweb/client.go +++ b/grpcweb/client.go @@ -21,6 +21,7 @@ import ( "github.com/mstoykov/k6-taskqueue-lib/taskqueue" "go.k6.io/k6/js/common" "go.k6.io/k6/js/modules" + "go.k6.io/k6/lib/types" "go.k6.io/k6/metrics" "golang.org/x/net/http2" "google.golang.org/grpc" @@ -193,13 +194,19 @@ func (c *client) Invoke(method string, req sobek.Value, params sobek.Value) (*in return nil, fmt.Errorf("request cannot be nil") } - connectReq, ctm, err := c.buildRequest(md, req, params) + connectReq, ctm, timeout, err := c.buildRequest(md, req, params) if err != nil { return nil, err } c.setSystemTags(ctm, c.addr, method) - ctx := c.vu.Context() + if timeout <= 0 { + // default timeout is 2 minutes + timeout = 2 * time.Minute + } + + ctx, cancel := context.WithTimeout(c.vu.Context(), timeout) + defer cancel() resp, err := c.callUnary(ctx, method, connectReq, ctm) if err != nil { @@ -239,7 +246,7 @@ func (c *client) AsyncInvoke(method string, req sobek.Value, params sobek.Value) return promise } - connectReq, ctm, err := c.buildRequest(md, req, params) + connectReq, ctm, timeout, err := c.buildRequest(md, req, params) if err != nil { reject(err) return promise @@ -248,7 +255,14 @@ func (c *client) AsyncInvoke(method string, req sobek.Value, params sobek.Value) callback := c.vu.RegisterCallback() - ctx := c.vu.Context() + if timeout <= 0 { + // default timeout is 2 minutes + timeout = 2 * time.Minute + } + + ctx, cancel := context.WithTimeout(c.vu.Context(), timeout) + defer cancel() + go func() { resp, err := c.callUnary(ctx, method, connectReq, ctm) @@ -325,12 +339,18 @@ func (c *client) Stream(method string, req, params sobek.Value) (*sobek.Object, connect.WithGRPCWeb(), ) - connectReq, ctm, err := c.buildRequest(md, req, params) + connectReq, ctm, timeout, err := c.buildRequest(md, req, params) if err != nil { return nil, err } c.setSystemTags(ctm, c.addr, method) + ctx := c.vu.Context() + var cancel context.CancelFunc + if timeout > 0 { + ctx, cancel = context.WithTimeout(ctx, timeout) + } + s := &stream{ vu: c.vu, metrics: c.metrics, @@ -339,9 +359,10 @@ func (c *client) Stream(method string, req, params sobek.Value) (*sobek.Object, md: md, eventListeners: newEventListeners(), tq: taskqueue.New(c.vu.RegisterCallback), + cancel: cancel, } - if err := s.begin(connectReq); err != nil { + if err := s.begin(ctx, connectReq); err != nil { return nil, err } @@ -423,6 +444,7 @@ func (c *client) parseConnectParams(params sobek.Value) (connectParams, error) { type callParams struct { metadata http.Header tagsAndMeta metrics.TagsAndMeta + timeout time.Duration } func (c *client) parseCallParams(params sobek.Value) (callParams, error) { @@ -431,6 +453,7 @@ func (c *client) parseCallParams(params sobek.Value) (callParams, error) { result := callParams{ metadata: http.Header{}, tagsAndMeta: c.vu.State().Tags.GetCurrentValues(), + timeout: 0, } if params != nil { @@ -460,30 +483,36 @@ func (c *client) parseCallParams(params sobek.Value) (callParams, error) { if err := common.ApplyCustomUserTags(rt, &result.tagsAndMeta, paramsObject.Get(k)); err != nil { return result, fmt.Errorf("metric tags: %w", err) } + case "timeout": + timeout, err := types.GetDurationValue(v.Export()) + if err != nil { + return result, fmt.Errorf("invalid timeout value: %w", err) + } + result.timeout = timeout } } } return result, nil } -func (c *client) buildRequest(md protoreflect.MethodDescriptor, req sobek.Value, params sobek.Value) (*connect.Request[dynamicpb.Message], *metrics.TagsAndMeta, error) { +func (c *client) buildRequest(md protoreflect.MethodDescriptor, req sobek.Value, params sobek.Value) (*connect.Request[dynamicpb.Message], *metrics.TagsAndMeta, time.Duration, error) { rt := c.vu.Runtime() b, err := req.ToObject(rt).MarshalJSON() if err != nil { - return nil, nil, err + return nil, nil, 0, err } reqdm := dynamicpb.NewMessage(md.Input()) err = protojson.Unmarshal(b, reqdm) if err != nil { - return nil, nil, err + return nil, nil, 0, err } r := connect.NewRequest(reqdm) p, err := c.parseCallParams(params) if err != nil { - return nil, nil, err + return nil, nil, 0, err } // headers @@ -491,7 +520,7 @@ func (c *client) buildRequest(md protoreflect.MethodDescriptor, req sobek.Value, r.Header()[k] = v } - return r, &p.tagsAndMeta, nil + return r, &p.tagsAndMeta, p.timeout, nil } func (c *client) setSystemTags(ctm *metrics.TagsAndMeta, addr *url.URL, method string) { diff --git a/grpcweb/stream.go b/grpcweb/stream.go index 968b683..6eb1643 100644 --- a/grpcweb/stream.go +++ b/grpcweb/stream.go @@ -1,6 +1,7 @@ package grpcweb import ( + "context" "errors" "fmt" "sync" @@ -74,6 +75,8 @@ type stream struct { tq *taskqueue.TaskQueue stream *connect.ServerStreamForClient[deferredMessage] + + cancel context.CancelFunc } func (s *stream) On(eventType string, handler func(sobek.Value) (sobek.Value, error)) { @@ -86,9 +89,7 @@ func (s *stream) On(eventType string, handler func(sobek.Value) (sobek.Value, er } } -func (s *stream) begin(req *connect.Request[dynamicpb.Message]) error { - ctx := s.vu.Context() - +func (s *stream) begin(ctx context.Context, req *connect.Request[dynamicpb.Message]) error { stream, err := s.client.CallServerStream(ctx, req) if err != nil { return err @@ -197,4 +198,8 @@ func (s *stream) queueClose() { }) return }) + + if s.cancel != nil { + s.cancel() + } }