diff --git a/CHANGELOG.md b/CHANGELOG.md index 349a2c58558..efcd947d725 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,6 +34,7 @@ This project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.htm - The function signature of the Span `RecordError` method in `go.opentelemetry.io/otel` is updated to no longer take an unused context and instead take a required error value and a variable number of `EventOption`s. (#1254) - Move the `go.opentelemetry.io/otel/api/global` package to `go.opentelemetry.io/otel/global`. (#1262) - Rename correlation context header from `"otcorrelations"` to `"baggage"` to match the OpenTelemetry specification. (#1267) +- Fix `Code.UnmarshalJSON` to work with valid json only. (#1276) ### Removed diff --git a/codes/codes.go b/codes/codes.go index 28393a54400..ec863e5345d 100644 --- a/codes/codes.go +++ b/codes/codes.go @@ -19,6 +19,7 @@ package codes // import "go.opentelemetry.io/otel/codes" import ( + "encoding/json" "fmt" "strconv" ) @@ -45,9 +46,9 @@ var codeToStr = map[Code]string{ } var strToCode = map[string]Code{ - "Unset": Unset, - "Error": Error, - "Ok": Ok, + `"Unset"`: Unset, + `"Error"`: Error, + `"Ok"`: Ok, } // String returns the Code as a string. @@ -70,20 +71,30 @@ func (c *Code) UnmarshalJSON(b []byte) error { return fmt.Errorf("nil receiver passed to UnmarshalJSON") } - if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil { - if ci >= maxCode { - return fmt.Errorf("invalid code: %q", ci) - } - - *c = Code(ci) - return nil + var x interface{} + if err := json.Unmarshal(b, &x); err != nil { + return err } + switch x.(type) { + case string: + if jc, ok := strToCode[string(b)]; ok { + *c = jc + return nil + } + return fmt.Errorf("invalid code: %q", string(b)) + case float64: + if ci, err := strconv.ParseUint(string(b), 10, 32); err == nil { + if ci >= maxCode { + return fmt.Errorf("invalid code: %q", ci) + } - if jc, ok := strToCode[string(b)]; ok { - *c = jc - return nil + *c = Code(ci) + return nil + } + return fmt.Errorf("invalid code: %q", string(b)) + default: + return fmt.Errorf("invalid code: %q", string(b)) } - return fmt.Errorf("invalid code: %q", string(b)) } // MarshalJSON returns c as the JSON encoding of c. diff --git a/codes/codes_test.go b/codes/codes_test.go index 50fc6813481..3518af1d21b 100644 --- a/codes/codes_test.go +++ b/codes/codes_test.go @@ -16,6 +16,7 @@ package codes import ( "bytes" + "encoding/json" "fmt" "testing" ) @@ -61,17 +62,18 @@ func TestCodeUnmarshalJSON(t *testing.T) { want Code }{ {"0", Unset}, - {"Unset", Unset}, + {`"Unset"`, Unset}, {"1", Error}, - {"Error", Error}, + {`"Error"`, Error}, {"2", Ok}, - {"Ok", Ok}, + {`"Ok"`, Ok}, } for _, test := range tests { c := new(Code) *c = Code(maxCode) - if err := c.UnmarshalJSON([]byte(test.input)); err != nil { - t.Fatalf("Code.UnmarshalJSON(%q) errored: %v", test.input, err) + + if err := json.Unmarshal([]byte(test.input), c); err != nil { + t.Fatalf("json.Unmarshal(%q, Code) errored: %v", test.input, err) } if *c != test.want { t.Errorf("failed to unmarshal %q as %v", test.input, test.want) @@ -83,11 +85,15 @@ func TestCodeUnmarshalJSONErrorInvalidData(t *testing.T) { tests := []string{ fmt.Sprintf("%d", maxCode), "Not a code", + "Unset", + "true", + `"Not existing"`, + "", } c := new(Code) for _, test := range tests { - if err := c.UnmarshalJSON([]byte(test)); err == nil { - t.Fatalf("Code.UnmarshalJSON(%q) did not error", test) + if err := json.Unmarshal([]byte(test), c); err == nil { + t.Fatalf("json.Unmarshal(%q, Code) did not error", test) } } } @@ -133,3 +139,30 @@ func TestCodeMarshalJSONErrorInvalid(t *testing.T) { t.Fatal("Code(maxCode).MarshalJSON() returned non-nil value") } } + +func TestRoundTripCodes(t *testing.T) { + tests := []struct { + input Code + }{ + {Unset}, + {Error}, + {Ok}, + } + for _, test := range tests { + c := test.input + out := new(Code) + + b, err := c.MarshalJSON() + if err != nil { + t.Fatalf("Code(%s).MarshalJSON() errored: %v", test.input, err) + } + + if err := out.UnmarshalJSON(b); err != nil { + t.Fatalf("Code.UnmarshalJSON(%q) errored: %v", c, err) + } + + if *out != test.input { + t.Errorf("failed to round trip %q, output was %v", test.input, out) + } + } +}