1+ package main
2+
3+ import (
4+ "context"
5+ "crypto/rand"
6+ "errors"
7+ "fmt"
8+ "io"
9+ "log/slog"
10+ "net"
11+ "sync"
12+ "syscall"
13+ "time"
14+
15+ "github.com/rcrowley/go-metrics"
16+ "go.uber.org/ratelimit"
17+ "golang.org/x/sync/errgroup"
18+ )
19+
20+ const (
21+ flavorPersistent = "persistent"
22+ flavorEphemeral = "ephemeral"
23+ )
24+
25+ type ClientConfig struct {
26+ Protocol string
27+ ConnectFlavor string
28+ Connections int32
29+ ConnectRate int32
30+ Duration time.Duration
31+ MessageBytes int32
32+ MergeResultsEachHost bool
33+ }
34+
35+ type Client struct {
36+ config ClientConfig
37+ }
38+
39+ func NewClient (config ClientConfig ) * Client {
40+ return & Client {config : config }
41+ }
42+
43+ func waitLim (ctx context.Context , rl ratelimit.Limiter ) error {
44+ select {
45+ case <- ctx .Done ():
46+ return ctx .Err ()
47+ default :
48+ done := make (chan struct {})
49+ go func () {
50+ rl .Take ()
51+ close (done )
52+ }()
53+ select {
54+ case <- done :
55+ return nil
56+ case <- ctx .Done ():
57+ return ctx .Err ()
58+ }
59+ }
60+ }
61+
62+ func getOrRegisterTimer (key , addr string , mergeResultsEachHost bool ) metrics.Timer {
63+ if mergeResultsEachHost {
64+ return metrics .GetOrRegisterTimer (key , nil )
65+ }
66+ return metrics .GetOrRegisterTimer (key + "." + addr , nil )
67+ }
68+
69+ func unregisterTimer (key , addr string , mergeResultsEachHost bool ) {
70+ if mergeResultsEachHost {
71+ metrics .Unregister (key )
72+ return
73+ }
74+ metrics .Unregister (key + "." + addr )
75+ }
76+
77+ func measureTime (addr string , mergeResultsEachHost bool , f func () error ) error {
78+ ts := getOrRegisterTimer ("total.latency" , addr , mergeResultsEachHost )
79+ is := getOrRegisterTimer ("tick.latency" , addr , mergeResultsEachHost )
80+ start := time .Now ()
81+ if err := f (); err != nil {
82+ return err
83+ }
84+ elapsed := time .Since (start )
85+ ts .Update (elapsed )
86+ is .Update (elapsed )
87+ return nil
88+ }
89+
90+ func (c * Client ) ConnectToAddresses (ctx context.Context , addrs []string ) error {
91+ eg , ctx := errgroup .WithContext (ctx )
92+ for _ , addr := range addrs {
93+ addr := addr
94+ eg .Go (func () error {
95+ return c .connectAddr (ctx , addr )
96+ })
97+ }
98+
99+ if err := eg .Wait (); err != nil {
100+ return fmt .Errorf ("connection error: %w" , err )
101+ }
102+ return nil
103+ }
104+
105+ func (c * Client ) connectAddr (ctx context.Context , addr string ) error {
106+ switch c .config .Protocol {
107+ case "tcp" :
108+ switch c .config .ConnectFlavor {
109+ case flavorPersistent :
110+ return c .connectPersistent (ctx , addr )
111+ case flavorEphemeral :
112+ return c .connectEphemeral (ctx , addr )
113+ }
114+ case "udp" :
115+ return c .connectUDP (ctx , addr )
116+ }
117+ return fmt .Errorf ("invalid protocol or flavor combination" )
118+ }
119+
120+ func (c * Client ) connectPersistent (ctx context.Context , addrport string ) error {
121+ ctx , cancel := context .WithTimeout (ctx , c .config .Duration )
122+ defer cancel ()
123+
124+ bufTCPPool := sync.Pool {
125+ New : func () interface {} { return make ([]byte , c .config .MessageBytes ) },
126+ }
127+
128+ dialer := net.Dialer {
129+ Control : GetTCPControlWithFastOpen (),
130+ }
131+
132+ eg , ctx := errgroup .WithContext (ctx )
133+ for i := 0 ; i < int (c .config .Connections ); i ++ {
134+ eg .Go (func () error {
135+ conn , err := dialer .Dial ("tcp" , addrport )
136+ if err != nil {
137+ return fmt .Errorf ("dialing %q: %w" , addrport , err )
138+ }
139+ defer conn .Close ()
140+
141+ msgsTotal := int64 (c .config .ConnectRate ) * int64 (c .config .Duration .Seconds ())
142+ limiter := ratelimit .New (int (c .config .ConnectRate ))
143+
144+ for j := int64 (0 ); j < msgsTotal ; j ++ {
145+ if err := waitLim (ctx , limiter ); err != nil {
146+ if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
147+ return nil
148+ }
149+ continue
150+ }
151+
152+ if err := measureTime (addrport , c .config .MergeResultsEachHost , func () error {
153+ msg := bufTCPPool .Get ().([]byte )
154+ defer bufTCPPool .Put (msg )
155+
156+ if n , err := rand .Read (msg ); err != nil {
157+ return fmt .Errorf ("generating random data (length:%d): %w" , n , err )
158+ }
159+
160+ if _ , err := conn .Write (msg ); err != nil {
161+ return fmt .Errorf ("writing to connection: %w" , err )
162+ }
163+ if _ , err := conn .Read (msg ); err != nil {
164+ return fmt .Errorf ("reading from connection: %w" , err )
165+ }
166+ return nil
167+ }); err != nil {
168+ return err
169+ }
170+ }
171+ return nil
172+ })
173+ }
174+ return eg .Wait ()
175+ }
176+
177+ func (c * Client ) connectEphemeral (ctx context.Context , addrport string ) error {
178+ ctx , cancel := context .WithTimeout (ctx , c .config .Duration )
179+ defer cancel ()
180+
181+ bufTCPPool := sync.Pool {
182+ New : func () interface {} { return make ([]byte , c .config .MessageBytes ) },
183+ }
184+
185+ dialer := net.Dialer {
186+ Control : GetTCPControlWithFastOpen (),
187+ }
188+
189+ connTotal := int64 (c .config .ConnectRate ) * int64 (c .config .Duration .Seconds ())
190+ limiter := ratelimit .New (int (c .config .ConnectRate ))
191+
192+ eg , ctx := errgroup .WithContext (ctx )
193+ for i := int64 (0 ); i < connTotal ; i ++ {
194+ if err := waitLim (ctx , limiter ); err != nil {
195+ if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
196+ break
197+ }
198+ continue
199+ }
200+
201+ eg .Go (func () error {
202+ return measureTime (addrport , c .config .MergeResultsEachHost , func () error {
203+ conn , err := dialer .Dial ("tcp" , addrport )
204+ if err != nil {
205+ if errors .Is (err , syscall .ETIMEDOUT ) {
206+ slog .Warn ("connection timeout" , "addr" , addrport )
207+ return nil
208+ }
209+ return fmt .Errorf ("dialing %q: %w" , addrport , err )
210+ }
211+ defer conn .Close ()
212+
213+ if err := SetLinger (conn ); err != nil {
214+ return fmt .Errorf ("setting linger: %w" , err )
215+ }
216+ if err := SetQuickAck (conn ); err != nil {
217+ return fmt .Errorf ("setting quick ack: %w" , err )
218+ }
219+
220+ msg := bufTCPPool .Get ().([]byte )
221+ defer bufTCPPool .Put (msg )
222+
223+ if n , err := rand .Read (msg ); err != nil {
224+ return fmt .Errorf ("generating random data (length:%d): %w" , n , err )
225+ }
226+
227+ if _ , err := conn .Write (msg ); err != nil {
228+ if errors .Is (err , syscall .EINPROGRESS ) {
229+ slog .Warn ("write in progress" , "addr" , addrport )
230+ return nil
231+ }
232+ return fmt .Errorf ("writing to connection: %w" , err )
233+ }
234+
235+ if _ , err := conn .Read (msg ); err != nil {
236+ if errors .Is (err , syscall .ECONNRESET ) {
237+ slog .Warn ("connection reset" , "addr" , addrport )
238+ return nil
239+ }
240+ return fmt .Errorf ("reading from connection: %w" , err )
241+ }
242+
243+ return nil
244+ })
245+ })
246+ }
247+ return eg .Wait ()
248+ }
249+
250+ func (c * Client ) connectUDP (ctx context.Context , addrport string ) error {
251+ ctx , cancel := context .WithTimeout (ctx , c .config .Duration )
252+ defer cancel ()
253+
254+ connTotal := int64 (c .config .ConnectRate ) * int64 (c .config .Duration .Seconds ())
255+ limiter := ratelimit .New (int (c .config .ConnectRate ))
256+
257+ bufUDPPool := sync.Pool {
258+ New : func () interface {} { return make ([]byte , c .config .MessageBytes ) },
259+ }
260+
261+ eg , ctx := errgroup .WithContext (ctx )
262+ for i := int64 (0 ); i < connTotal ; i ++ {
263+ if err := waitLim (ctx , limiter ); err != nil {
264+ if errors .Is (err , context .Canceled ) || errors .Is (err , context .DeadlineExceeded ) {
265+ break
266+ }
267+ continue
268+ }
269+
270+ eg .Go (func () error {
271+ return measureTime (addrport , c .config .MergeResultsEachHost , func () error {
272+ conn , err := net .Dial ("udp4" , addrport )
273+ if err != nil {
274+ return fmt .Errorf ("dialing UDP %q: %w" , addrport , err )
275+ }
276+ defer conn .Close ()
277+
278+ msg := bufUDPPool .Get ().([]byte )
279+ defer bufUDPPool .Put (msg )
280+
281+ if n , err := rand .Read (msg ); err != nil {
282+ return fmt .Errorf ("generating random data (length:%d): %w" , n , err )
283+ }
284+
285+ if _ , err := conn .Write (msg ); err != nil {
286+ return fmt .Errorf ("writing to UDP connection: %w" , err )
287+ }
288+
289+ if _ , err := conn .Read (msg ); err != nil {
290+ return fmt .Errorf ("reading from UDP connection: %w" , err )
291+ }
292+
293+ return nil
294+ })
295+ })
296+ }
297+ return eg .Wait ()
298+ }
299+
300+ func toMicroseconds (n int64 ) int64 {
301+ return time .Duration (n ).Microseconds ()
302+ }
303+
304+ func toMicrosecondsf (n float64 ) int64 {
305+ return time .Duration (n ).Microseconds ()
306+ }
307+
308+ func printStatHeader (w io.Writer ) {
309+ fmt .Fprintf (w , "%-20s %-10s %-15s %-15s %-15s %-15s %-15s %-15s %-10s\n " ,
310+ "PEER" , "CNT" , "LAT_MAX(µs)" , "LAT_MIN(µs)" , "LAT_MEAN(µs)" ,
311+ "LAT_90p(µs)" , "LAT_95p(µs)" , "LAT_99p(µs)" , "RATE(/s)" )
312+ }
313+
314+ func printStatLine (w io.Writer , addr string , stat metrics.Timer ) {
315+ fmt .Fprintf (w , "%-20s %-10d %-15d %-15d %-15d %-15d %-15d %-15d %-10.2f\n " ,
316+ addr ,
317+ stat .Count (),
318+ toMicroseconds (stat .Max ()),
319+ toMicroseconds (stat .Min ()),
320+ toMicrosecondsf (stat .Mean ()),
321+ toMicrosecondsf (stat .Percentile (0.9 )),
322+ toMicrosecondsf (stat .Percentile (0.95 )),
323+ toMicrosecondsf (stat .Percentile (0.99 )),
324+ stat .RateMean (),
325+ )
326+ }
327+
328+ func runStatLinePrinter (ctx context.Context , w io.Writer , addr string , intervalStats time.Duration , mergeResultsEachHost bool ) {
329+ go func () {
330+ ticker := time .NewTicker (intervalStats )
331+ defer ticker .Stop ()
332+
333+ for {
334+ select {
335+ case <- ctx .Done ():
336+ return
337+ case <- ticker .C :
338+ ts := getOrRegisterTimer ("tick.latency" , addr , mergeResultsEachHost )
339+ printStatLine (w , addr , ts )
340+ unregisterTimer ("tick.latency" , addr , mergeResultsEachHost )
341+ }
342+ }
343+ }()
344+ }
345+
346+ func printReport (w io.Writer , addrs []string , mergeResultsEachHost bool ) {
347+ fmt .Fprintln (w , "--- A result during total execution time ---" )
348+ if mergeResultsEachHost {
349+ ts := getOrRegisterTimer ("total.latency" , "" , mergeResultsEachHost )
350+ printStatLine (w , fmt .Sprintf ("merged(%d hosts)" , len (addrs )), ts )
351+ return
352+ }
353+ for _ , addr := range addrs {
354+ ts := getOrRegisterTimer ("total.latency" , addr , mergeResultsEachHost )
355+ printStatLine (w , addr , ts )
356+ }
357+ }
0 commit comments