Skip to content

Commit dde7e9c

Browse files
committed
refactor: separate CLI flag and TCP/UDP logics
- Added a new `client.go` file to handle client connections and configurations. - Introduced a `server.go` file to manage TCP and UDP server functionalities. - Refactored connection handling in tests to utilize the new client structure. - Enhanced error handling and logging for server operations. - Updated tests to validate the new server-client interactions.
1 parent e5be955 commit dde7e9c

File tree

9 files changed

+810
-612
lines changed

9 files changed

+810
-612
lines changed

CLAUDE.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,4 +61,8 @@ The project includes test files (cmd/serve_test.go). When adding new functionali
6161
- File descriptor limits are automatically raised via the `limit` package
6262
- Server gracefully handles signals (SIGINT, SIGTERM) for clean shutdown
6363
- Error handling includes specific logic for network timeouts and connection resets
64-
- Build uses Go modules with go 1.24.4
64+
- Build uses Go modules with go 1.24.4
65+
66+
## Additional Guidance
67+
68+
- Place all go files into top-level directory

client.go

Lines changed: 357 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,357 @@
1+
package main
2+
3+
import (
4+
"context"
5+
"crypto/rand"
6+
"errors"
7+
"fmt"
8+
"io"
9+
"log/slog"
10+
"net"
11+
"sync"
12+
"syscall"
13+
"time"
14+
15+
"github.com/rcrowley/go-metrics"
16+
"go.uber.org/ratelimit"
17+
"golang.org/x/sync/errgroup"
18+
)
19+
20+
const (
21+
flavorPersistent = "persistent"
22+
flavorEphemeral = "ephemeral"
23+
)
24+
25+
type ClientConfig struct {
26+
Protocol string
27+
ConnectFlavor string
28+
Connections int32
29+
ConnectRate int32
30+
Duration time.Duration
31+
MessageBytes int32
32+
MergeResultsEachHost bool
33+
}
34+
35+
type Client struct {
36+
config ClientConfig
37+
}
38+
39+
func NewClient(config ClientConfig) *Client {
40+
return &Client{config: config}
41+
}
42+
43+
func waitLim(ctx context.Context, rl ratelimit.Limiter) error {
44+
select {
45+
case <-ctx.Done():
46+
return ctx.Err()
47+
default:
48+
done := make(chan struct{})
49+
go func() {
50+
rl.Take()
51+
close(done)
52+
}()
53+
select {
54+
case <-done:
55+
return nil
56+
case <-ctx.Done():
57+
return ctx.Err()
58+
}
59+
}
60+
}
61+
62+
func getOrRegisterTimer(key, addr string, mergeResultsEachHost bool) metrics.Timer {
63+
if mergeResultsEachHost {
64+
return metrics.GetOrRegisterTimer(key, nil)
65+
}
66+
return metrics.GetOrRegisterTimer(key+"."+addr, nil)
67+
}
68+
69+
func unregisterTimer(key, addr string, mergeResultsEachHost bool) {
70+
if mergeResultsEachHost {
71+
metrics.Unregister(key)
72+
return
73+
}
74+
metrics.Unregister(key + "." + addr)
75+
}
76+
77+
func measureTime(addr string, mergeResultsEachHost bool, f func() error) error {
78+
ts := getOrRegisterTimer("total.latency", addr, mergeResultsEachHost)
79+
is := getOrRegisterTimer("tick.latency", addr, mergeResultsEachHost)
80+
start := time.Now()
81+
if err := f(); err != nil {
82+
return err
83+
}
84+
elapsed := time.Since(start)
85+
ts.Update(elapsed)
86+
is.Update(elapsed)
87+
return nil
88+
}
89+
90+
func (c *Client) ConnectToAddresses(ctx context.Context, addrs []string) error {
91+
eg, ctx := errgroup.WithContext(ctx)
92+
for _, addr := range addrs {
93+
addr := addr
94+
eg.Go(func() error {
95+
return c.connectAddr(ctx, addr)
96+
})
97+
}
98+
99+
if err := eg.Wait(); err != nil {
100+
return fmt.Errorf("connection error: %w", err)
101+
}
102+
return nil
103+
}
104+
105+
func (c *Client) connectAddr(ctx context.Context, addr string) error {
106+
switch c.config.Protocol {
107+
case "tcp":
108+
switch c.config.ConnectFlavor {
109+
case flavorPersistent:
110+
return c.connectPersistent(ctx, addr)
111+
case flavorEphemeral:
112+
return c.connectEphemeral(ctx, addr)
113+
}
114+
case "udp":
115+
return c.connectUDP(ctx, addr)
116+
}
117+
return fmt.Errorf("invalid protocol or flavor combination")
118+
}
119+
120+
func (c *Client) connectPersistent(ctx context.Context, addrport string) error {
121+
ctx, cancel := context.WithTimeout(ctx, c.config.Duration)
122+
defer cancel()
123+
124+
bufTCPPool := sync.Pool{
125+
New: func() interface{} { return make([]byte, c.config.MessageBytes) },
126+
}
127+
128+
dialer := net.Dialer{
129+
Control: GetTCPControlWithFastOpen(),
130+
}
131+
132+
eg, ctx := errgroup.WithContext(ctx)
133+
for i := 0; i < int(c.config.Connections); i++ {
134+
eg.Go(func() error {
135+
conn, err := dialer.Dial("tcp", addrport)
136+
if err != nil {
137+
return fmt.Errorf("dialing %q: %w", addrport, err)
138+
}
139+
defer conn.Close()
140+
141+
msgsTotal := int64(c.config.ConnectRate) * int64(c.config.Duration.Seconds())
142+
limiter := ratelimit.New(int(c.config.ConnectRate))
143+
144+
for j := int64(0); j < msgsTotal; j++ {
145+
if err := waitLim(ctx, limiter); err != nil {
146+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
147+
return nil
148+
}
149+
continue
150+
}
151+
152+
if err := measureTime(addrport, c.config.MergeResultsEachHost, func() error {
153+
msg := bufTCPPool.Get().([]byte)
154+
defer bufTCPPool.Put(msg)
155+
156+
if n, err := rand.Read(msg); err != nil {
157+
return fmt.Errorf("generating random data (length:%d): %w", n, err)
158+
}
159+
160+
if _, err := conn.Write(msg); err != nil {
161+
return fmt.Errorf("writing to connection: %w", err)
162+
}
163+
if _, err := conn.Read(msg); err != nil {
164+
return fmt.Errorf("reading from connection: %w", err)
165+
}
166+
return nil
167+
}); err != nil {
168+
return err
169+
}
170+
}
171+
return nil
172+
})
173+
}
174+
return eg.Wait()
175+
}
176+
177+
func (c *Client) connectEphemeral(ctx context.Context, addrport string) error {
178+
ctx, cancel := context.WithTimeout(ctx, c.config.Duration)
179+
defer cancel()
180+
181+
bufTCPPool := sync.Pool{
182+
New: func() interface{} { return make([]byte, c.config.MessageBytes) },
183+
}
184+
185+
dialer := net.Dialer{
186+
Control: GetTCPControlWithFastOpen(),
187+
}
188+
189+
connTotal := int64(c.config.ConnectRate) * int64(c.config.Duration.Seconds())
190+
limiter := ratelimit.New(int(c.config.ConnectRate))
191+
192+
eg, ctx := errgroup.WithContext(ctx)
193+
for i := int64(0); i < connTotal; i++ {
194+
if err := waitLim(ctx, limiter); err != nil {
195+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
196+
break
197+
}
198+
continue
199+
}
200+
201+
eg.Go(func() error {
202+
return measureTime(addrport, c.config.MergeResultsEachHost, func() error {
203+
conn, err := dialer.Dial("tcp", addrport)
204+
if err != nil {
205+
if errors.Is(err, syscall.ETIMEDOUT) {
206+
slog.Warn("connection timeout", "addr", addrport)
207+
return nil
208+
}
209+
return fmt.Errorf("dialing %q: %w", addrport, err)
210+
}
211+
defer conn.Close()
212+
213+
if err := SetLinger(conn); err != nil {
214+
return fmt.Errorf("setting linger: %w", err)
215+
}
216+
if err := SetQuickAck(conn); err != nil {
217+
return fmt.Errorf("setting quick ack: %w", err)
218+
}
219+
220+
msg := bufTCPPool.Get().([]byte)
221+
defer bufTCPPool.Put(msg)
222+
223+
if n, err := rand.Read(msg); err != nil {
224+
return fmt.Errorf("generating random data (length:%d): %w", n, err)
225+
}
226+
227+
if _, err := conn.Write(msg); err != nil {
228+
if errors.Is(err, syscall.EINPROGRESS) {
229+
slog.Warn("write in progress", "addr", addrport)
230+
return nil
231+
}
232+
return fmt.Errorf("writing to connection: %w", err)
233+
}
234+
235+
if _, err := conn.Read(msg); err != nil {
236+
if errors.Is(err, syscall.ECONNRESET) {
237+
slog.Warn("connection reset", "addr", addrport)
238+
return nil
239+
}
240+
return fmt.Errorf("reading from connection: %w", err)
241+
}
242+
243+
return nil
244+
})
245+
})
246+
}
247+
return eg.Wait()
248+
}
249+
250+
func (c *Client) connectUDP(ctx context.Context, addrport string) error {
251+
ctx, cancel := context.WithTimeout(ctx, c.config.Duration)
252+
defer cancel()
253+
254+
connTotal := int64(c.config.ConnectRate) * int64(c.config.Duration.Seconds())
255+
limiter := ratelimit.New(int(c.config.ConnectRate))
256+
257+
bufUDPPool := sync.Pool{
258+
New: func() interface{} { return make([]byte, c.config.MessageBytes) },
259+
}
260+
261+
eg, ctx := errgroup.WithContext(ctx)
262+
for i := int64(0); i < connTotal; i++ {
263+
if err := waitLim(ctx, limiter); err != nil {
264+
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
265+
break
266+
}
267+
continue
268+
}
269+
270+
eg.Go(func() error {
271+
return measureTime(addrport, c.config.MergeResultsEachHost, func() error {
272+
conn, err := net.Dial("udp4", addrport)
273+
if err != nil {
274+
return fmt.Errorf("dialing UDP %q: %w", addrport, err)
275+
}
276+
defer conn.Close()
277+
278+
msg := bufUDPPool.Get().([]byte)
279+
defer bufUDPPool.Put(msg)
280+
281+
if n, err := rand.Read(msg); err != nil {
282+
return fmt.Errorf("generating random data (length:%d): %w", n, err)
283+
}
284+
285+
if _, err := conn.Write(msg); err != nil {
286+
return fmt.Errorf("writing to UDP connection: %w", err)
287+
}
288+
289+
if _, err := conn.Read(msg); err != nil {
290+
return fmt.Errorf("reading from UDP connection: %w", err)
291+
}
292+
293+
return nil
294+
})
295+
})
296+
}
297+
return eg.Wait()
298+
}
299+
300+
func toMicroseconds(n int64) int64 {
301+
return time.Duration(n).Microseconds()
302+
}
303+
304+
func toMicrosecondsf(n float64) int64 {
305+
return time.Duration(n).Microseconds()
306+
}
307+
308+
func printStatHeader(w io.Writer) {
309+
fmt.Fprintf(w, "%-20s %-10s %-15s %-15s %-15s %-15s %-15s %-15s %-10s\n",
310+
"PEER", "CNT", "LAT_MAX(µs)", "LAT_MIN(µs)", "LAT_MEAN(µs)",
311+
"LAT_90p(µs)", "LAT_95p(µs)", "LAT_99p(µs)", "RATE(/s)")
312+
}
313+
314+
func printStatLine(w io.Writer, addr string, stat metrics.Timer) {
315+
fmt.Fprintf(w, "%-20s %-10d %-15d %-15d %-15d %-15d %-15d %-15d %-10.2f\n",
316+
addr,
317+
stat.Count(),
318+
toMicroseconds(stat.Max()),
319+
toMicroseconds(stat.Min()),
320+
toMicrosecondsf(stat.Mean()),
321+
toMicrosecondsf(stat.Percentile(0.9)),
322+
toMicrosecondsf(stat.Percentile(0.95)),
323+
toMicrosecondsf(stat.Percentile(0.99)),
324+
stat.RateMean(),
325+
)
326+
}
327+
328+
func runStatLinePrinter(ctx context.Context, w io.Writer, addr string, intervalStats time.Duration, mergeResultsEachHost bool) {
329+
go func() {
330+
ticker := time.NewTicker(intervalStats)
331+
defer ticker.Stop()
332+
333+
for {
334+
select {
335+
case <-ctx.Done():
336+
return
337+
case <-ticker.C:
338+
ts := getOrRegisterTimer("tick.latency", addr, mergeResultsEachHost)
339+
printStatLine(w, addr, ts)
340+
unregisterTimer("tick.latency", addr, mergeResultsEachHost)
341+
}
342+
}
343+
}()
344+
}
345+
346+
func printReport(w io.Writer, addrs []string, mergeResultsEachHost bool) {
347+
fmt.Fprintln(w, "--- A result during total execution time ---")
348+
if mergeResultsEachHost {
349+
ts := getOrRegisterTimer("total.latency", "", mergeResultsEachHost)
350+
printStatLine(w, fmt.Sprintf("merged(%d hosts)", len(addrs)), ts)
351+
return
352+
}
353+
for _, addr := range addrs {
354+
ts := getOrRegisterTimer("total.latency", addr, mergeResultsEachHost)
355+
printStatLine(w, addr, ts)
356+
}
357+
}

0 commit comments

Comments
 (0)