diff --git a/btcjson/chainsvrresults.go b/btcjson/chainsvrresults.go index 11c0483d31..433bdda8eb 100644 --- a/btcjson/chainsvrresults.go +++ b/btcjson/chainsvrresults.go @@ -8,6 +8,7 @@ import ( "bytes" "encoding/hex" "encoding/json" + "fmt" "github.com/btcsuite/btcd/chaincfg/chainhash" @@ -363,6 +364,49 @@ type LocalAddressesResult struct { Score int32 `json:"score"` } +// StringOrArray defines a type that can be used as type that is either a single +// string value or a string array in JSON-RPC commands, depending on the version +// of the chain backend. +type StringOrArray []string + +// MarshalJSON implements the json.Marshaler interface. +func (h StringOrArray) MarshalJSON() ([]byte, error) { + return json.Marshal(h) +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (h *StringOrArray) UnmarshalJSON(data []byte) error { + var unmarshalled interface{} + if err := json.Unmarshal(data, &unmarshalled); err != nil { + return err + } + + switch v := unmarshalled.(type) { + case string: + *h = []string{v} + + case []interface{}: + s := make([]string, len(v)) + for i, e := range v { + str, ok := e.(string) + if !ok { + return fmt.Errorf("invalid string_or_array "+ + "value: %v", unmarshalled) + } + + s[i] = str + } + + *h = s + + default: + return fmt.Errorf("invalid string_or_array value: %v", + unmarshalled) + } + + return nil +} + // GetNetworkInfoResult models the data returned from the getnetworkinfo // command. type GetNetworkInfoResult struct { @@ -380,7 +424,7 @@ type GetNetworkInfoResult struct { RelayFee float64 `json:"relayfee"` IncrementalFee float64 `json:"incrementalfee"` LocalAddresses []LocalAddressesResult `json:"localaddresses"` - Warnings string `json:"warnings"` + Warnings StringOrArray `json:"warnings"` } // GetNodeAddressesResult models the data returned from the getnodeaddresses diff --git a/btcjson/chainsvrresults_test.go b/btcjson/chainsvrresults_test.go index 2566e65f62..122af3dccc 100644 --- a/btcjson/chainsvrresults_test.go +++ b/btcjson/chainsvrresults_test.go @@ -215,3 +215,51 @@ func TestChainSvrMiningInfoResults(t *testing.T) { } } } + +// TestGetNetworkInfoWarnings tests that we can use both a single string value +// and an array of string values for the warnings field in GetNetworkInfoResult. +func TestGetNetworkInfoWarnings(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + result string + expected btcjson.GetNetworkInfoResult + }{ + { + name: "network info with single warning", + result: `{"warnings": "this is a warning"}`, + expected: btcjson.GetNetworkInfoResult{ + Warnings: btcjson.StringOrArray{ + "this is a warning", + }, + }, + }, + { + name: "network info with array of warnings", + result: `{"warnings": ["a", "or", "b"]}`, + expected: btcjson.GetNetworkInfoResult{ + Warnings: btcjson.StringOrArray{ + "a", "or", "b", + }, + }, + }, + } + + t.Logf("Running %d tests", len(tests)) + for i, test := range tests { + var infoResult btcjson.GetNetworkInfoResult + err := json.Unmarshal([]byte(test.result), &infoResult) + if err != nil { + t.Errorf("Test #%d (%s) unexpected error: %v", i, + test.name, err) + continue + } + if !reflect.DeepEqual(infoResult, test.expected) { + t.Errorf("Test #%d (%s) unexpected marhsalled data - "+ + "got %+v, want %+v", i, test.name, infoResult, + test.expected) + continue + } + } +}