Skip to content

Commit 63ae567

Browse files
committed
[minor] added support for Join
1 parent e08bc56 commit 63ae567

File tree

5 files changed

+117
-27
lines changed

5 files changed

+117
-27
lines changed

errors_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,14 @@ func TestSetDefaultType(t *testing.T) {
134134
wantErrType: TypeInputBody,
135135
},
136136
}
137+
137138
for _, tt := range tests {
138139
t.Run(tt.name, func(t *testing.T) {
140+
before := defaultErrType
139141
SetDefaultType(tt.args.e)
140142
err := New(tt.args.message)
143+
// resetting to previous value to stop messing with the entire package
144+
SetDefaultType(before)
141145
if err.Type() != tt.wantErrType {
142146
t.Errorf(
143147
"New() = got type '%d', expected '%d",

helper.go

Lines changed: 61 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -292,23 +292,38 @@ func DownstreamDependencyTimedoutErrf(original error, format string, args ...int
292292
}
293293

294294
// HTTPStatusCodeMessage returns the appropriate HTTP status code, message, boolean for the error
295-
// the boolean value is true if the error was of type *Error, false otherwise
295+
// the boolean value is true if the error was of type *Error, false otherwise.
296296
func HTTPStatusCodeMessage(err error) (int, string, bool) {
297-
derr, _ := err.(*Error)
298-
if derr != nil {
299-
return derr.HTTPStatusCode(), derr.Message(), true
297+
code, isErr := HTTPStatusCode(err)
298+
msg, isErrMsg := Message(err)
299+
if msg == "" {
300+
msg = err.Error()
300301
}
301-
302-
return http.StatusInternalServerError, err.Error(), false
302+
return code, msg, isErr && isErrMsg
303303
}
304304

305305
// HTTPStatusCode returns appropriate HTTP response status code based on type of the error. The boolean
306306
// is 'true' if the provided error is of type *Err
307+
// In case of joined errors, it'll return the status code of the last *Error
307308
func HTTPStatusCode(err error) (int, bool) {
308309
derr, _ := err.(*Error)
309310
if derr != nil {
310311
return derr.HTTPStatusCode(), true
311312
}
313+
314+
jerr, _ := err.(*joinError)
315+
if jerr != nil {
316+
elen := len(jerr.errs)
317+
isErr := true
318+
for i := elen - 1; i >= 0; i-- {
319+
code, isE := HTTPStatusCode(jerr.errs[i])
320+
isErr = isE && isErr
321+
if isE {
322+
return code, isErr
323+
}
324+
}
325+
}
326+
312327
return http.StatusInternalServerError, false
313328
}
314329

@@ -325,6 +340,22 @@ func Message(err error) (string, bool) {
325340
if derr != nil {
326341
return derr.Message(), true
327342
}
343+
344+
jerr, _ := err.(*joinError)
345+
if jerr != nil {
346+
list := make([]string, 0, len(jerr.errs))
347+
isErr := true
348+
for i := range jerr.errs {
349+
msg, ok := Message(jerr.errs[i])
350+
isErr = isErr && ok
351+
if msg == "" {
352+
continue
353+
}
354+
list = append(list, msg)
355+
}
356+
return strings.Join(list, "\n"), isErr
357+
}
358+
328359
return "", false
329360
}
330361

@@ -338,12 +369,23 @@ func WriteHTTP(err error, w http.ResponseWriter) {
338369
}
339370

340371
// Type returns the errType if it's an instance of *Error, -1 otherwise
372+
// In case of joined error, it'll return the type of the last *Error
341373
func Type(err error) errType {
342-
e, ok := err.(*Error)
343-
if !ok {
344-
return errType(-1)
374+
e, _ := err.(*Error)
375+
if e != nil {
376+
return e.Type()
345377
}
346-
return e.Type()
378+
je, _ := err.(*joinError)
379+
if je != nil {
380+
for i := len(je.errs) - 1; i >= 0; i-- {
381+
et := Type(je.errs[i])
382+
if et.Int() != -1 {
383+
return et
384+
}
385+
}
386+
}
387+
388+
return errType(-1)
347389
}
348390

349391
// Type returns the errType as integer if it's an instance of *Error, -1 otherwise
@@ -362,6 +404,15 @@ func HasType(err error, et errType) bool {
362404
return HasType(errors.Unwrap(err), et)
363405
}
364406

407+
je, _ := err.(*joinError)
408+
if je != nil {
409+
for i := 0; i < len(je.errs); i++ {
410+
if HasType(je.errs[i], et) {
411+
return true
412+
}
413+
}
414+
}
415+
365416
if e.Type() == et {
366417
return true
367418
}

helper_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -189,17 +189,17 @@ func TestHTTPStatusCodeMessage(t *testing.T) {
189189
want2: true,
190190
},
191191
}
192-
for _, tt := range tests {
192+
for idx, tt := range tests {
193193
t.Run(tt.name, func(t *testing.T) {
194194
got, got1, got2 := HTTPStatusCodeMessage(tt.args.err)
195195
if got != tt.want {
196-
t.Errorf("HTTPStatusCodeMessage() got = %v, want %v", got, tt.want)
196+
t.Errorf("[%d] HTTPStatusCodeMessage() got = %v, want %v", idx, got, tt.want)
197197
}
198198
if got1 != tt.want1 {
199-
t.Errorf("HTTPStatusCodeMessage() got1 = %v, want %v", got1, tt.want1)
199+
t.Errorf("[%d] HTTPStatusCodeMessage() got1 = %v, want %v", idx, got1, tt.want1)
200200
}
201201
if got2 != tt.want2 {
202-
t.Errorf("HTTPStatusCodeMessage() got2 = %v, want %v", got2, tt.want2)
202+
t.Errorf("[%d] HTTPStatusCodeMessage() got2 = %v, want %v", idx, got2, tt.want2)
203203
}
204204
})
205205
}

mirror.go

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -17,27 +17,20 @@ func As(err error, target interface{}) bool {
1717
return errors.As(err, target)
1818
}
1919

20-
// Join returns an error that combines all the given errors.
21-
// This is the exact implementation found in Go v1.20.
22-
// It will be removed when Go >= v1.20 becomes the LTS version, and would just call the
23-
// native Join after that.
2420
func Join(errs ...error) error {
25-
n := 0
26-
for _, err := range errs {
27-
if err != nil {
28-
n++
29-
}
30-
}
21+
n := len(errs)
3122
if n == 0 {
3223
return nil
3324
}
25+
3426
e := &joinError{
3527
errs: make([]error, 0, n),
3628
}
3729
for _, err := range errs {
38-
if err != nil {
39-
e.errs = append(e.errs, err)
30+
if err == nil {
31+
continue
4032
}
33+
e.errs = append(e.errs, err)
4134
}
4235
return e
4336
}

mirror_test.go

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package errors
33
import (
44
"errors"
55
"fmt"
6+
"net/http"
67
"strings"
78
"testing"
89
)
@@ -123,9 +124,50 @@ func TestAs(t *testing.T) {
123124
}
124125

125126
out := target.Error()
126-
want := "/errors/mirror_test.go:120: type *Error"
127+
want := "/errors/mirror_test.go:121: type *Error"
127128
if !strings.Contains(out, want) {
128129
t.Errorf("Error() = %s, want %s", out, want)
129130
}
131+
}
132+
133+
func TestJoin(t *testing.T) {
134+
joined := Join(
135+
errors.New("[1] std"),
136+
New("[2] custom"),
137+
errors.New("[3] std"),
138+
Validation("[4] validation"),
139+
)
140+
got := fmt.Sprint(joined)
141+
if !strings.Contains(got, "[1] std") ||
142+
!strings.Contains(got, "mirror_test.go:136: [2] custom") ||
143+
!strings.Contains(got, "[3] std") ||
144+
!strings.Contains(got, "mirror_test.go:138: [4] validation") {
145+
t.Error(got)
146+
}
147+
148+
msg, ok := Message(joined)
149+
if ok {
150+
t.Errorf(
151+
"Expected: false, got: %v",
152+
ok,
153+
)
154+
}
155+
expectedMsg := `[2] custom
156+
[4] validation`
157+
if msg != expectedMsg {
158+
t.Error(msg)
159+
}
130160

161+
expectedCode := http.StatusUnprocessableEntity
162+
code, msg, ok := HTTPStatusCodeMessage(joined)
163+
if ok {
164+
t.Errorf("expected false, got: %v", ok)
165+
}
166+
if expectedMsg != msg ||
167+
expectedCode != code {
168+
t.Errorf(
169+
"Expected msg: %s, expected code: %d, got msg: %s, got code: %d",
170+
expectedMsg, expectedCode, msg, code,
171+
)
172+
}
131173
}

0 commit comments

Comments
 (0)