|
1 |
| -// Copyright 2020-2022 The NATS Authors |
| 1 | +// Copyright 2020-2023 The NATS Authors |
2 | 2 | // Licensed under the Apache License, Version 2.0 (the "License");
|
3 | 3 | // you may not use this file except in compliance with the License.
|
4 | 4 | // You may obtain a copy of the License at
|
@@ -2673,16 +2673,6 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
|
2673 | 2673 | return nil, err
|
2674 | 2674 | }
|
2675 | 2675 |
|
2676 |
| - // Use the deadline of the context to base the expire times. |
2677 |
| - deadline, _ := ctx.Deadline() |
2678 |
| - ttl = time.Until(deadline) |
2679 |
| - checkCtxErr := func(err error) error { |
2680 |
| - if o.ctx == nil && err == context.DeadlineExceeded { |
2681 |
| - return ErrTimeout |
2682 |
| - } |
2683 |
| - return err |
2684 |
| - } |
2685 |
| - |
2686 | 2676 | var (
|
2687 | 2677 | msgs = make([]*Msg, 0, batch)
|
2688 | 2678 | msg *Msg
|
@@ -2716,7 +2706,7 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
|
2716 | 2706 | sendReq := func() error {
|
2717 | 2707 | // The current deadline for the context will be used
|
2718 | 2708 | // to set the expires TTL for a fetch request.
|
2719 |
| - deadline, _ = ctx.Deadline() |
| 2709 | + deadline, _ := ctx.Deadline() |
2720 | 2710 | ttl = time.Until(deadline)
|
2721 | 2711 |
|
2722 | 2712 | // Check if context has already been canceled or expired.
|
@@ -2766,11 +2756,235 @@ func (sub *Subscription) Fetch(batch int, opts ...PullOpt) ([]*Msg, error) {
|
2766 | 2756 | }
|
2767 | 2757 | // If there is at least a message added to msgs, then need to return OK and no error
|
2768 | 2758 | if err != nil && len(msgs) == 0 {
|
2769 |
| - return nil, checkCtxErr(err) |
| 2759 | + return nil, o.checkCtxErr(err) |
2770 | 2760 | }
|
2771 | 2761 | return msgs, nil
|
2772 | 2762 | }
|
2773 | 2763 |
|
| 2764 | +// MessageBatch provides methods to retrieve messages consumed using [Subscribe.FetchBatch]. |
| 2765 | +type MessageBatch interface { |
| 2766 | + // Messages returns a channel on which messages will be published. |
| 2767 | + Messages() <-chan *Msg |
| 2768 | + |
| 2769 | + // Error returns an error encountered when fetching messages. |
| 2770 | + Error() error |
| 2771 | + |
| 2772 | + // Done signals end of execution. |
| 2773 | + Done() <-chan struct{} |
| 2774 | +} |
| 2775 | + |
| 2776 | +type messageBatch struct { |
| 2777 | + msgs chan *Msg |
| 2778 | + err error |
| 2779 | + done chan struct{} |
| 2780 | +} |
| 2781 | + |
| 2782 | +func (mb *messageBatch) Messages() <-chan *Msg { |
| 2783 | + return mb.msgs |
| 2784 | +} |
| 2785 | + |
| 2786 | +func (mb *messageBatch) Error() error { |
| 2787 | + return mb.err |
| 2788 | +} |
| 2789 | + |
| 2790 | +func (mb *messageBatch) Done() <-chan struct{} { |
| 2791 | + return mb.done |
| 2792 | +} |
| 2793 | + |
| 2794 | +// FetchBatch pulls a batch of messages from a stream for a pull consumer. |
| 2795 | +// Unlike [Subscription.Fetch], it is non blocking and returns [MessageBatch], |
| 2796 | +// allowing to retrieve incoming messages from a channel. |
| 2797 | +// The returned channel is always closed after all messages for a batch have been |
| 2798 | +// delivered by the server - it is safe to iterate over it using range. |
| 2799 | +// |
| 2800 | +// To avoid using default JetStream timeout as fetch expiry time, use [nats.MaxWait] |
| 2801 | +// or [nats.Context] (with deadline set). |
| 2802 | +// |
| 2803 | +// This method will not return error in case of pull request expiry (even if there are no messages). |
| 2804 | +// Any other error encountered when receiving messages will cause FetchBatch to stop receiving new messages. |
| 2805 | +func (sub *Subscription) FetchBatch(batch int, opts ...PullOpt) (MessageBatch, error) { |
| 2806 | + if sub == nil { |
| 2807 | + return nil, ErrBadSubscription |
| 2808 | + } |
| 2809 | + if batch < 1 { |
| 2810 | + return nil, ErrInvalidArg |
| 2811 | + } |
| 2812 | + |
| 2813 | + var o pullOpts |
| 2814 | + for _, opt := range opts { |
| 2815 | + if err := opt.configurePull(&o); err != nil { |
| 2816 | + return nil, err |
| 2817 | + } |
| 2818 | + } |
| 2819 | + if o.ctx != nil && o.ttl != 0 { |
| 2820 | + return nil, ErrContextAndTimeout |
| 2821 | + } |
| 2822 | + sub.mu.Lock() |
| 2823 | + jsi := sub.jsi |
| 2824 | + // Reject if this is not a pull subscription. Note that sub.typ is SyncSubscription, |
| 2825 | + // so check for jsi.pull boolean instead. |
| 2826 | + if jsi == nil || !jsi.pull { |
| 2827 | + sub.mu.Unlock() |
| 2828 | + return nil, ErrTypeSubscription |
| 2829 | + } |
| 2830 | + |
| 2831 | + nc := sub.conn |
| 2832 | + nms := sub.jsi.nms |
| 2833 | + rply := sub.jsi.deliver |
| 2834 | + js := sub.jsi.js |
| 2835 | + pmc := len(sub.mch) > 0 |
| 2836 | + |
| 2837 | + // All fetch requests have an expiration, in case of no explicit expiration |
| 2838 | + // then the default timeout of the JetStream context is used. |
| 2839 | + ttl := o.ttl |
| 2840 | + if ttl == 0 { |
| 2841 | + ttl = js.opts.wait |
| 2842 | + } |
| 2843 | + sub.mu.Unlock() |
| 2844 | + |
| 2845 | + // Use the given context or setup a default one for the span |
| 2846 | + // of the pull batch request. |
| 2847 | + var ( |
| 2848 | + ctx = o.ctx |
| 2849 | + cancel context.CancelFunc |
| 2850 | + cancelContext = true |
| 2851 | + ) |
| 2852 | + if ctx == nil { |
| 2853 | + ctx, cancel = context.WithTimeout(context.Background(), ttl) |
| 2854 | + } else if _, hasDeadline := ctx.Deadline(); !hasDeadline { |
| 2855 | + // Prevent from passing the background context which will just block |
| 2856 | + // and cannot be canceled either. |
| 2857 | + if octx, ok := ctx.(ContextOpt); ok && octx.Context == context.Background() { |
| 2858 | + return nil, ErrNoDeadlineContext |
| 2859 | + } |
| 2860 | + |
| 2861 | + // If the context did not have a deadline, then create a new child context |
| 2862 | + // that will use the default timeout from the JS context. |
| 2863 | + ctx, cancel = context.WithTimeout(ctx, ttl) |
| 2864 | + } |
| 2865 | + defer func() { |
| 2866 | + // only cancel the context here if we are sure the fetching goroutine has not been started yet |
| 2867 | + if cancel != nil && cancelContext { |
| 2868 | + cancel() |
| 2869 | + } |
| 2870 | + }() |
| 2871 | + |
| 2872 | + // Check if context not done already before making the request. |
| 2873 | + select { |
| 2874 | + case <-ctx.Done(): |
| 2875 | + if o.ctx != nil { // Timeout or Cancel triggered by context object option |
| 2876 | + return nil, ctx.Err() |
| 2877 | + } else { // Timeout triggered by timeout option |
| 2878 | + return nil, ErrTimeout |
| 2879 | + } |
| 2880 | + default: |
| 2881 | + } |
| 2882 | + |
| 2883 | + result := &messageBatch{ |
| 2884 | + msgs: make(chan *Msg, batch), |
| 2885 | + done: make(chan struct{}, 1), |
| 2886 | + } |
| 2887 | + var msg *Msg |
| 2888 | + for pmc && len(result.msgs) < batch { |
| 2889 | + // Check next msg with booleans that say that this is an internal call |
| 2890 | + // for a pull subscribe (so don't reject it) and don't wait if there |
| 2891 | + // are no messages. |
| 2892 | + msg, err := sub.nextMsgWithContext(ctx, true, false) |
| 2893 | + if err != nil { |
| 2894 | + if err == errNoMessages { |
| 2895 | + err = nil |
| 2896 | + } |
| 2897 | + result.err = err |
| 2898 | + break |
| 2899 | + } |
| 2900 | + // Check msg but just to determine if this is a user message |
| 2901 | + // or status message, however, we don't care about values of status |
| 2902 | + // messages at this point in the Fetch() call, so checkMsg can't |
| 2903 | + // return an error. |
| 2904 | + if usrMsg, _ := checkMsg(msg, false, false); usrMsg { |
| 2905 | + result.msgs <- msg |
| 2906 | + } |
| 2907 | + } |
| 2908 | + if len(result.msgs) == batch || result.err != nil { |
| 2909 | + close(result.msgs) |
| 2910 | + result.done <- struct{}{} |
| 2911 | + return result, nil |
| 2912 | + } |
| 2913 | + |
| 2914 | + deadline, _ := ctx.Deadline() |
| 2915 | + ttl = time.Until(deadline) |
| 2916 | + |
| 2917 | + // Make our request expiration a bit shorter than the current timeout. |
| 2918 | + expires := ttl |
| 2919 | + if ttl >= 20*time.Millisecond { |
| 2920 | + expires = ttl - 10*time.Millisecond |
| 2921 | + } |
| 2922 | + |
| 2923 | + requestBatch := batch - len(result.msgs) |
| 2924 | + req := nextRequest{ |
| 2925 | + Expires: expires, |
| 2926 | + Batch: requestBatch, |
| 2927 | + MaxBytes: o.maxBytes, |
| 2928 | + } |
| 2929 | + reqJSON, err := json.Marshal(req) |
| 2930 | + if err != nil { |
| 2931 | + close(result.msgs) |
| 2932 | + result.done <- struct{}{} |
| 2933 | + result.err = err |
| 2934 | + return result, nil |
| 2935 | + } |
| 2936 | + if err := nc.PublishRequest(nms, rply, reqJSON); err != nil { |
| 2937 | + if len(result.msgs) == 0 { |
| 2938 | + return nil, err |
| 2939 | + } |
| 2940 | + close(result.msgs) |
| 2941 | + result.done <- struct{}{} |
| 2942 | + result.err = err |
| 2943 | + return result, nil |
| 2944 | + } |
| 2945 | + cancelContext = false |
| 2946 | + go func() { |
| 2947 | + if cancel != nil { |
| 2948 | + defer cancel() |
| 2949 | + } |
| 2950 | + var requestMsgs int |
| 2951 | + for requestMsgs < requestBatch { |
| 2952 | + // Ask for next message and wait if there are no messages |
| 2953 | + msg, err = sub.nextMsgWithContext(ctx, true, true) |
| 2954 | + if err != nil { |
| 2955 | + break |
| 2956 | + } |
| 2957 | + var usrMsg bool |
| 2958 | + |
| 2959 | + usrMsg, err = checkMsg(msg, true, false) |
| 2960 | + if err != nil { |
| 2961 | + if err == ErrTimeout { |
| 2962 | + err = nil |
| 2963 | + } |
| 2964 | + break |
| 2965 | + } |
| 2966 | + if usrMsg { |
| 2967 | + result.msgs <- msg |
| 2968 | + requestMsgs++ |
| 2969 | + } |
| 2970 | + } |
| 2971 | + if err != nil { |
| 2972 | + result.err = o.checkCtxErr(err) |
| 2973 | + } |
| 2974 | + close(result.msgs) |
| 2975 | + result.done <- struct{}{} |
| 2976 | + }() |
| 2977 | + return result, nil |
| 2978 | +} |
| 2979 | + |
| 2980 | +// checkCtxErr is used to determine whether ErrTimeout should be returned in case of context timeout |
| 2981 | +func (o *pullOpts) checkCtxErr(err error) error { |
| 2982 | + if o.ctx == nil && err == context.DeadlineExceeded { |
| 2983 | + return ErrTimeout |
| 2984 | + } |
| 2985 | + return err |
| 2986 | +} |
| 2987 | + |
2774 | 2988 | func (js *js) getConsumerInfo(stream, consumer string) (*ConsumerInfo, error) {
|
2775 | 2989 | ctx, cancel := context.WithTimeout(context.Background(), js.opts.wait)
|
2776 | 2990 | defer cancel()
|
|
0 commit comments