diff --git a/common/extension/rest_client.go b/common/extension/rest_client.go index 7096d23810..e63c9a36db 100644 --- a/common/extension/rest_client.go +++ b/common/extension/rest_client.go @@ -22,7 +22,7 @@ import ( ) var ( - restClients = make(map[string]func(restOptions *rest_interface.RestOptions) rest_interface.RestClient) + restClients = make(map[string]func(restOptions *rest_interface.RestOptions) rest_interface.RestClient, 8) ) func SetRestClient(name string, fun func(restOptions *rest_interface.RestOptions) rest_interface.RestClient) { diff --git a/common/extension/rest_server.go b/common/extension/rest_server.go index 06e1757e76..8ba5b65ca5 100644 --- a/common/extension/rest_server.go +++ b/common/extension/rest_server.go @@ -22,7 +22,7 @@ import ( ) var ( - restServers = make(map[string]func() rest_interface.RestServer) + restServers = make(map[string]func() rest_interface.RestServer, 8) ) func SetRestServer(name string, fun func() rest_interface.RestServer) { diff --git a/common/yaml/yaml.go b/common/yaml/yaml.go new file mode 100644 index 0000000000..79e5f67d8f --- /dev/null +++ b/common/yaml/yaml.go @@ -0,0 +1,26 @@ +package yaml + +import ( + "io/ioutil" + "path" +) + +import ( + perrors "github.com/pkg/errors" + "gopkg.in/yaml.v2" +) + +func UnmarshalYMLConfig(yamlFile string, out interface{}) error { + if path.Ext(yamlFile) != ".yml" { + return perrors.Errorf("yamlFile name{%v} suffix must be .yml", yamlFile) + } + confFileStream, err := ioutil.ReadFile(yamlFile) + if err != nil { + return perrors.Errorf("ioutil.ReadFile(file:%s) = error:%v", yamlFile, perrors.WithStack(err)) + } + err = yaml.Unmarshal(confFileStream, out) + if err != nil { + return perrors.Errorf("yaml.Unmarshal() = error:%v", perrors.WithStack(err)) + } + return nil +} diff --git a/config/consumer_config.go b/config/consumer_config.go index cd583954ff..c3fe12d703 100644 --- a/config/consumer_config.go +++ b/config/consumer_config.go @@ -18,8 +18,6 @@ package config import ( - "io/ioutil" - "path" "time" ) @@ -27,12 +25,12 @@ import ( "github.com/creasty/defaults" "github.com/dubbogo/getty" perrors "github.com/pkg/errors" - "gopkg.in/yaml.v2" ) import ( "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/common/yaml" ) ///////////////////////// @@ -90,21 +88,11 @@ func ConsumerInit(confConFile string) error { if confConFile == "" { return perrors.Errorf("application configure(consumer) file name is nil") } - - if path.Ext(confConFile) != ".yml" { - return perrors.Errorf("application configure file name{%v} suffix must be .yml", confConFile) - } - - confFileStream, err := ioutil.ReadFile(confConFile) - if err != nil { - return perrors.Errorf("ioutil.ReadFile(file:%s) = error:%v", confConFile, perrors.WithStack(err)) - } consumerConfig = &ConsumerConfig{} - err = yaml.Unmarshal(confFileStream, consumerConfig) + err := yaml.UnmarshalYMLConfig(confConFile, consumerConfig) if err != nil { - return perrors.Errorf("yaml.Unmarshal() = error:%v", perrors.WithStack(err)) + return perrors.Errorf("unmarshalYmlConfig error %v", perrors.WithStack(err)) } - //set method interfaceId & interfaceName for k, v := range consumerConfig.References { //set id for reference diff --git a/config/provider_config.go b/config/provider_config.go index 7bed561d99..d8562863ed 100644 --- a/config/provider_config.go +++ b/config/provider_config.go @@ -17,20 +17,15 @@ package config -import ( - "io/ioutil" - "path" -) - import ( "github.com/creasty/defaults" perrors "github.com/pkg/errors" - "gopkg.in/yaml.v2" ) import ( "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/common/yaml" ) ///////////////////////// @@ -81,21 +76,11 @@ func ProviderInit(confProFile string) error { if len(confProFile) == 0 { return perrors.Errorf("application configure(provider) file name is nil") } - - if path.Ext(confProFile) != ".yml" { - return perrors.Errorf("application configure file name{%v} suffix must be .yml", confProFile) - } - - confFileStream, err := ioutil.ReadFile(confProFile) - if err != nil { - return perrors.Errorf("ioutil.ReadFile(file:%s) = error:%v", confProFile, perrors.WithStack(err)) - } providerConfig = &ProviderConfig{} - err = yaml.Unmarshal(confFileStream, providerConfig) + err := yaml.UnmarshalYMLConfig(confProFile, providerConfig) if err != nil { - return perrors.Errorf("yaml.Unmarshal() = error:%v", perrors.WithStack(err)) + return perrors.Errorf("unmarshalYmlConfig error %v", perrors.WithStack(err)) } - //set method interfaceId & interfaceName for k, v := range providerConfig.Services { //set id for reference diff --git a/protocol/rest/rest_client/resty_client.go b/protocol/rest/rest_client/resty_client.go index 88c3cc77e1..13b409817f 100644 --- a/protocol/rest/rest_client/resty_client.go +++ b/protocol/rest/rest_client/resty_client.go @@ -25,12 +25,9 @@ import ( "time" ) -import ( - perrors "github.com/pkg/errors" -) - import ( "github.com/go-resty/resty/v2" + perrors "github.com/pkg/errors" ) import ( diff --git a/protocol/rest/rest_config_initializer.go b/protocol/rest/rest_config_initializer.go index 13a463b049..a80dce98b1 100644 --- a/protocol/rest/rest_config_initializer.go +++ b/protocol/rest/rest_config_initializer.go @@ -45,7 +45,7 @@ func initConsumerRestConfig() { consumerConfigType := config.GetConsumerConfig().RestConfigType consumerConfigReader := extension.GetSingletonRestConfigReader(consumerConfigType) restConsumerConfig := consumerConfigReader.ReadConsumerConfig() - if restConsumerConfig == nil { + if restConsumerConfig == nil || len(restConsumerConfig.RestServiceConfigsMap) == 0 { return } restConsumerServiceConfigMap = make(map[string]*rest_interface.RestServiceConfig, len(restConsumerConfig.RestServiceConfigsMap)) @@ -60,7 +60,7 @@ func initProviderRestConfig() { providerConfigType := config.GetProviderConfig().RestConfigType providerConfigReader := extension.GetSingletonRestConfigReader(providerConfigType) restProviderConfig := providerConfigReader.ReadProviderConfig() - if restProviderConfig == nil { + if restProviderConfig == nil || len(restProviderConfig.RestServiceConfigsMap) == 0 { return } restProviderServiceConfigMap = make(map[string]*rest_interface.RestServiceConfig, len(restProviderConfig.RestServiceConfigsMap)) @@ -125,7 +125,7 @@ func transformMethodConfig(methodConfig *rest_interface.RestMethodConfig) *rest_ } func parseParamsString2Map(params string) (map[int]string, error) { - m := make(map[int]string) + m := make(map[int]string, 8) for _, p := range strings.Split(params, ",") { pa := strings.Split(p, ":") key, err := strconv.Atoi(pa[0]) diff --git a/protocol/rest/rest_config_reader/default_config_reader.go b/protocol/rest/rest_config_reader/default_config_reader.go index 3eb29af8f3..c0a1002409 100644 --- a/protocol/rest/rest_config_reader/default_config_reader.go +++ b/protocol/rest/rest_config_reader/default_config_reader.go @@ -18,20 +18,18 @@ package rest_config_reader import ( - "io/ioutil" "os" - "path" ) import ( perrors "github.com/pkg/errors" - "gopkg.in/yaml.v2" ) import ( "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/extension" "github.com/apache/dubbo-go/common/logger" + "github.com/apache/dubbo-go/common/yaml" "github.com/apache/dubbo-go/protocol/rest/rest_interface" ) @@ -52,23 +50,14 @@ func NewDefaultConfigReader() *DefaultConfigReader { func (dcr *DefaultConfigReader) ReadConsumerConfig() *rest_interface.RestConsumerConfig { confConFile := os.Getenv(constant.CONF_CONSUMER_FILE_PATH) - if confConFile == "" { - logger.Warnf("rest consumer configure(consumer) file name is nil") - return nil - } - if path.Ext(confConFile) != ".yml" { - logger.Warnf("rest consumer configure file name{%v} suffix must be .yml", confConFile) - return nil - } - confFileStream, err := ioutil.ReadFile(confConFile) - if err != nil { - logger.Warnf("ioutil.ReadFile(file:%s) = error:%v", confConFile, perrors.WithStack(err)) + if len(confConFile) == 0 { + logger.Warnf("[Rest Config] rest consumer configure(consumer) file name is nil") return nil } restConsumerConfig := &rest_interface.RestConsumerConfig{} - err = yaml.Unmarshal(confFileStream, restConsumerConfig) + err := yaml.UnmarshalYMLConfig(confConFile, restConsumerConfig) if err != nil { - logger.Warnf("yaml.Unmarshal() = error:%v", perrors.WithStack(err)) + logger.Errorf("[Rest Config] unmarshal Consumer RestYmlConfig error %v", perrors.WithStack(err)) return nil } return restConsumerConfig @@ -77,26 +66,15 @@ func (dcr *DefaultConfigReader) ReadConsumerConfig() *rest_interface.RestConsume func (dcr *DefaultConfigReader) ReadProviderConfig() *rest_interface.RestProviderConfig { confProFile := os.Getenv(constant.CONF_PROVIDER_FILE_PATH) if len(confProFile) == 0 { - logger.Warnf("rest provider configure(provider) file name is nil") - return nil - } - - if path.Ext(confProFile) != ".yml" { - logger.Warnf("rest provider configure file name{%v} suffix must be .yml", confProFile) - return nil - } - confFileStream, err := ioutil.ReadFile(confProFile) - if err != nil { - logger.Warnf("ioutil.ReadFile(file:%s) = error:%v", confProFile, perrors.WithStack(err)) + logger.Warnf("[Rest Config] rest provider configure(provider) file name is nil") return nil } restProviderConfig := &rest_interface.RestProviderConfig{} - err = yaml.Unmarshal(confFileStream, restProviderConfig) + err := yaml.UnmarshalYMLConfig(confProFile, restProviderConfig) if err != nil { - logger.Warnf("yaml.Unmarshal() = error:%v", perrors.WithStack(err)) + logger.Errorf("[Rest Config] unmarshal Provider RestYmlConfig error %v", perrors.WithStack(err)) return nil } - return restProviderConfig } diff --git a/protocol/rest/rest_invoker.go b/protocol/rest/rest_invoker.go index 83ce07323d..446d122e6d 100644 --- a/protocol/rest/rest_invoker.go +++ b/protocol/rest/rest_invoker.go @@ -99,11 +99,10 @@ func (ri *RestInvoker) Invoke(ctx context.Context, invocation protocol.Invocatio func restStringMapTransform(paramsMap map[int]string, args []interface{}) (map[string]string, error) { resMap := make(map[string]string, len(paramsMap)) for k, v := range paramsMap { - if k < len(args) && k >= 0 { - resMap[v] = fmt.Sprint(args[k]) - } else { + if k >= len(args) || k < 0 { return nil, perrors.Errorf("[Rest Invoke] Index %v is out of bundle", k) } + resMap[v] = fmt.Sprint(args[k]) } return resMap, nil } diff --git a/protocol/rest/rest_protocol.go b/protocol/rest/rest_protocol.go index fdf941a0d5..f4f12e415f 100644 --- a/protocol/rest/rest_protocol.go +++ b/protocol/rest/rest_protocol.go @@ -46,10 +46,10 @@ func init() { type RestProtocol struct { protocol.BaseProtocol - serverMap map[string]rest_interface.RestServer - clientMap map[rest_interface.RestOptions]rest_interface.RestClient serverLock sync.Mutex + serverMap map[string]rest_interface.RestServer clientLock sync.Mutex + clientMap map[rest_interface.RestOptions]rest_interface.RestClient } func NewRestProtocol() *RestProtocol { @@ -97,33 +97,34 @@ func (rp *RestProtocol) Refer(url common.URL) protocol.Invoker { func (rp *RestProtocol) getServer(url common.URL, serverType string) rest_interface.RestServer { restServer, ok := rp.serverMap[url.Location] + if ok { + return restServer + } + _, ok = rp.ExporterMap().Load(url.ServiceKey()) if !ok { - _, ok := rp.ExporterMap().Load(url.ServiceKey()) - if !ok { - panic("[RestProtocol]" + url.ServiceKey() + "is not existing") - } - rp.serverLock.Lock() - restServer, ok = rp.serverMap[url.Location] - if !ok { - restServer = extension.GetNewRestServer(serverType) - restServer.Start(url) - rp.serverMap[url.Location] = restServer - } - rp.serverLock.Unlock() - + panic("[RestProtocol]" + url.ServiceKey() + "is not existing") + } + rp.serverLock.Lock() + restServer, ok = rp.serverMap[url.Location] + if !ok { + restServer = extension.GetNewRestServer(serverType) + restServer.Start(url) + rp.serverMap[url.Location] = restServer } + rp.serverLock.Unlock() return restServer } func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, clientType string) rest_interface.RestClient { restClient, ok := rp.clientMap[restOptions] + if ok { + return restClient + } rp.clientLock.Lock() + restClient, ok = rp.clientMap[restOptions] if !ok { - restClient, ok = rp.clientMap[restOptions] - if !ok { - restClient = extension.GetNewRestClient(clientType, &restOptions) - rp.clientMap[restOptions] = restClient - } + restClient = extension.GetNewRestClient(clientType, &restOptions) + rp.clientMap[restOptions] = restClient } rp.clientLock.Unlock() return restClient diff --git a/protocol/rest/rest_server/go_restful_server.go b/protocol/rest/rest_server/go_restful_server.go index 9dc42f832f..f6a867d773 100644 --- a/protocol/rest/rest_server/go_restful_server.go +++ b/protocol/rest/rest_server/go_restful_server.go @@ -91,18 +91,20 @@ func (grs *GoRestfulServer) Deploy(invoker protocol.Invoker, restMethodConfig ma } -func getFunc(methodName string, invoker protocol.Invoker, argsTypes []reflect.Type, replyType reflect.Type, config *rest_interface.RestMethodConfig) func(req *restful.Request, resp *restful.Response) { +func getFunc(methodName string, invoker protocol.Invoker, argsTypes []reflect.Type, + replyType reflect.Type, config *rest_interface.RestMethodConfig) func(req *restful.Request, resp *restful.Response) { return func(req *restful.Request, resp *restful.Response) { var ( err error args []interface{} ) - if (len(argsTypes) == 1 || len(argsTypes) == 2 && replyType == nil) && argsTypes[0].String() == "[]interface {}" { + if (len(argsTypes) == 1 || len(argsTypes) == 2 && replyType == nil) && + argsTypes[0].String() == "[]interface {}" { args = getArgsInterfaceFromRequest(req, config) } else { args = getArgsFromRequest(req, argsTypes, config) } - result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodName, args, make(map[string]string, 0))) + result := invoker.Invoke(context.Background(), invocation.NewRPCInvocation(methodName, args, make(map[string]string))) if result.Error() != nil { err = resp.WriteError(http.StatusInternalServerError, result.Error()) if err != nil { @@ -137,7 +139,7 @@ func (grs *GoRestfulServer) Destroy() { } func getArgsInterfaceFromRequest(req *restful.Request, config *rest_interface.RestMethodConfig) []interface{} { - argsMap := make(map[int]interface{}) + argsMap := make(map[int]interface{}, 8) maxKey := 0 for k, v := range config.PathParamsMap { if maxKey < k { @@ -166,7 +168,6 @@ func getArgsInterfaceFromRequest(req *restful.Request, config *rest_interface.Re if maxKey < config.Body { maxKey = config.Body } - m := make(map[string]interface{}) // TODO read as a slice if err := req.ReadEntity(&m); err != nil { @@ -175,7 +176,6 @@ func getArgsInterfaceFromRequest(req *restful.Request, config *rest_interface.Re argsMap[config.Body] = m } } - args := make([]interface{}, maxKey+1) for k, v := range argsMap { if k >= 0 {