Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,6 @@

// Do sends the HTTP request and returns after response is received.
func (c *Client) Do(ctx context.Context, r *Request) (*Response, error) {
if env.interceptor != nil {
return env.interceptor(ctx, r, c.do)
}
return c.do(ctx, r)
}

func (c *Client) do(ctx context.Context, r *Request) (*Response, error) {
if r.opts.DumpRequestOut != nil {
reqDump, err := httputil.DumpRequestOut(r.Request, true)
if err != nil {
Expand All @@ -38,6 +31,13 @@
if ctx != nil {
r = r.WithContext(ctx)
}
if env.interceptor != nil {
return env.interceptor(ctx, r, c.do)
}
return c.do(ctx, r)

Check warning on line 37 in client.go

View check run for this annotation

Codecov / codecov/patch

client.go#L37

Added line #L37 was not covered by tests
}

func (c *Client) do(ctx context.Context, r *Request) (*Response, error) {
// If the returned error is nil, the Response will contain
// a non-nil Body which the user is expected to close.
resp, err := c.Client.Do(r.Request)
Expand Down
26 changes: 10 additions & 16 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,14 +141,14 @@ func do(method, url string, opts *Options, stats *Stats) (*Response, error) {
func request(method, url string, opts *Options) (*Response, error) {
stats := &Stats{}
// NOTE: get the body size from io.Reader. It is costy for large body.
buf := &bytes.Buffer{}
var body bytes.Buffer
if opts.Body != nil {
n, err := io.Copy(buf, opts.Body)
_, err := io.Copy(&body, opts.Body)
if err != nil {
return nil, err
}
stats.BodySize = int(n)
opts.Body = buf
opts.Body = &body
stats.BodySize = body.Len()
}
return do(method, url, opts, stats)
}
Expand All @@ -160,8 +160,8 @@ func requestData(method, url string, opts *Options) (*Response, error) {
var body *strings.Reader
if opts.Data != nil {
d := fmt.Sprintf("%v", opts.Data)
stats.BodySize = len(d)
body = strings.NewReader(d)
stats.BodySize = body.Len()
}
// TODO: judge content type
// opts.Headers["Content-Type"] = "application/x-www-form-urlencoded"
Expand All @@ -176,8 +176,8 @@ func requestForm(method, urlStr string, opts *Options) (*Response, error) {
var body *strings.Reader
if opts.Form != nil {
d := opts.Form.Encode()
stats.BodySize = len(d)
body = strings.NewReader(d)
stats.BodySize = body.Len()
}
opts.Headers.Set("Content-Type", "application/x-www-form-urlencoded")
opts.Body = body
Expand All @@ -193,10 +193,9 @@ func requestJSON(method, url string, opts *Options) (*Response, error) {
if err != nil {
return nil, err
}
stats.BodySize = len(d)
body = bytes.NewBuffer(d)
stats.BodySize = body.Len()
}

opts.Headers.Set("Content-Type", "application/json")
opts.Body = body
return do(method, url, opts, stats)
Expand All @@ -216,20 +215,15 @@ func requestFiles(method, url string, opts *Options) (*Response, error) {
if _, err := io.Copy(fileWriter, f); err != nil {
return nil, err
}
fi, err := f.Stat()
if err != nil {
return nil, err
}
stats.BodySize += int(fi.Size())
}
}

opts.Headers.Set("Content-Type", bodyWriter.FormDataContentType())
opts.Body = &body
// write EOF before sending
if err := bodyWriter.Close(); err != nil {
return nil, err
}
opts.Headers.Set("Content-Type", bodyWriter.FormDataContentType())
opts.Body = &body
stats.BodySize = body.Len()
return do(method, url, opts, stats)
}

Expand Down
172 changes: 160 additions & 12 deletions request_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,19 @@ import (
"net/url"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"testing"
"time"

"github.com/stretchr/testify/require"
)

func init() {
WithInterceptor(logInterceptor, metricInterceptor, traceInterceptor, bodySizeChecker)
}

func logInterceptor(ctx context.Context, r *Request, do Do) (*Response, error) {
log.Printf("method: %s", r.Method)
return do(ctx, r)
Expand Down Expand Up @@ -48,10 +54,6 @@ func traceInterceptor(ctx context.Context, r *Request, do Do) (*Response, error)
return do(ctx, r)
}

func init() {
WithInterceptor(logInterceptor, metricInterceptor, traceInterceptor)
}

type testRequest struct {
Headers http.Header
Params url.Values
Expand Down Expand Up @@ -546,7 +548,7 @@ func TestPostFiles(t *testing.T) {
defer fh2.Close()

type args struct {
urlStr string
url string
options []Option
}
tests := []struct {
Expand All @@ -558,7 +560,7 @@ func TestPostFiles(t *testing.T) {
{
name: "upload file test case 1",
args: args{
urlStr: testServer.URL,
url: testServer.URL,
options: []Option{
Files(map[string]*os.File{
"file1": fh1,
Expand All @@ -572,7 +574,7 @@ func TestPostFiles(t *testing.T) {
{
name: "upload file test case 2",
args: args{
urlStr: "http://127.0.0.1:11111/unknown",
url: "http://127.0.0.1:11111/unknown",
options: []Option{
Files(map[string]*os.File{
"file1": fh1,
Expand All @@ -586,7 +588,8 @@ func TestPostFiles(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := Post(tt.args.urlStr, tt.args.options...)
var dump string
resp, err := Post(tt.args.url, append(tt.args.options, Dump(&dump, nil))...)
if (err != nil) != tt.wantErr {
t.Errorf("Post() error = %v, wantErr %v", err, tt.wantErr)
return
Expand Down Expand Up @@ -614,7 +617,7 @@ func TestPatch(t *testing.T) {
}))
defer testServer.Close()
type args struct {
urlStr string
url string
options []Option
}
tests := []struct {
Expand All @@ -626,7 +629,7 @@ func TestPatch(t *testing.T) {
{
name: "patch test case 1",
args: args{
urlStr: testServer.URL,
url: testServer.URL,
options: []Option{
JSON(map[string]any{
"status": 0,
Expand All @@ -640,7 +643,7 @@ func TestPatch(t *testing.T) {
{
name: "patch test case 2",
args: args{
urlStr: "http://127.0.0.1:11111/unknown",
url: "http://127.0.0.1:11111/unknown",
options: []Option{
JSON(map[string]any{
"status": 0,
Expand All @@ -654,7 +657,8 @@ func TestPatch(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resp, err := Patch(tt.args.urlStr, tt.args.options...)
var dump string
resp, err := Patch(tt.args.url, append(tt.args.options, Dump(&dump, nil))...)
if (err != nil) != tt.wantErr {
t.Errorf("Patch() error = %v, wantErr %v", err, tt.wantErr)
return
Expand All @@ -665,3 +669,147 @@ func TestPatch(t *testing.T) {
})
}
}

type ctxKey struct{}
type ctxValue struct {
t *testing.T
dump string
}

func newContext(ctx context.Context, value *ctxValue) context.Context {
return context.WithValue(ctx, ctxKey{}, value)
}

func fromContext(ctx context.Context) *ctxValue {
r, ok := ctx.Value(ctxKey{}).(*ctxValue)
if !ok {
return nil
}
return r
}

var re = regexp.MustCompile(`Content-Length: (\d+)`)

func getRequestContentLength(reqDump string) (int, error) {
match := re.FindStringSubmatch(reqDump)
if len(match) < 2 {
return 0, fmt.Errorf("Content-Length not found in request dump")
}
return strconv.Atoi(match[1])
}

func bodySizeChecker(ctx context.Context, r *Request, do Do) (*Response, error) {
if v := fromContext(ctx); v != nil {
contentLength, err := getRequestContentLength(v.dump)
require.NoError(v.t, err)
require.Equalf(v.t, contentLength, r.Stats.BodySize, "content length mismatch, got %v", r.Stats.BodySize)
}
return do(ctx, r)
}

func TestBodySize(t *testing.T) {
filename1 := "./testdata/file1.txt"
filename2 := "./testdata/file2.txt"
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
defer testServer.Close()

fh1, err := os.Open(filename1)
if err != nil {
t.Errorf("open file: %s failed: %+v", filename1, err)
return
}
defer fh1.Close()

fh2, err := os.Open(filename2)
if err != nil {
t.Errorf("open file: %s failed: %+v", filename2, err)
return
}
defer fh2.Close()

type args struct {
url string
options []Option
}
tests := []struct {
name string
args args
want *Response
wantErr bool
}{
{
name: "body",
args: args{
url: testServer.URL,
options: []Option{
Body(strings.NewReader("test1")),
},
},
wantErr: false,
},
{
name: "data",
args: args{
url: testServer.URL,
options: []Option{
Data("test1"),
},
},
wantErr: false,
},
{
name: "form",
args: args{
url: testServer.URL,
options: []Option{
Form(map[string]string{"form1": "value1"}),
Form(url.Values{"form2": []string{"value2", "value2-2"}}),
FormPairs("form1", "value1-2"),
FormPairs("form2", "value2-3", "form2", "value2-4"),
},
},
wantErr: false,
},
{
name: "json",
args: args{
url: testServer.URL,
options: []Option{
ParamPairs("param1", "value1"),
ParamPairs("param2", "value2"),
HeaderPairs("header1", "value1"),
HeaderPairs("header2", "value2"),
JSON(&EchoRequest{ID: 1, Name: "Hello"}),
},
},
wantErr: false,
},
{
name: "file",
args: args{
url: testServer.URL,
options: []Option{
Files(map[string]*os.File{
"file1": fh1,
"file2": fh2,
}),
Timeout(120 * time.Second),
},
},
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
v := &ctxValue{t: t, dump: ""}
resp, err := Post(tt.args.url, append(tt.args.options, Context(newContext(context.Background(), v)), Dump(&v.dump, nil))...)
if (err != nil) != tt.wantErr {
t.Errorf("Post() error = %v, wantErr %v", err, tt.wantErr)
return
}
if resp != nil {
t.Logf("resp: %s", resp.Text())
}
})
}
}