From fa27aa2386d688f29dbf2153dfd831e3ff6e237f Mon Sep 17 00:00:00 2001 From: Patrick Date: Mon, 9 Mar 2020 10:49:04 +0800 Subject: [PATCH] modify rest protocol --- protocol/rest/rest_protocol.go | 29 +++++++++++++++++++---------- 1 file changed, 19 insertions(+), 10 deletions(-) diff --git a/protocol/rest/rest_protocol.go b/protocol/rest/rest_protocol.go index 20aaee5dca..fdf941a0d5 100644 --- a/protocol/rest/rest_protocol.go +++ b/protocol/rest/rest_protocol.go @@ -27,6 +27,7 @@ import ( "github.com/apache/dubbo-go/common" "github.com/apache/dubbo-go/common/constant" "github.com/apache/dubbo-go/common/extension" + "github.com/apache/dubbo-go/common/logger" "github.com/apache/dubbo-go/config" "github.com/apache/dubbo-go/protocol" "github.com/apache/dubbo-go/protocol/rest/rest_interface" @@ -63,10 +64,14 @@ func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter { url := invoker.GetUrl() serviceKey := url.ServiceKey() exporter := NewRestExporter(serviceKey, invoker, rp.ExporterMap()) - restConfig := GetRestProviderServiceConfig(strings.TrimPrefix(url.Path, "/")) + restServiceConfig := GetRestProviderServiceConfig(strings.TrimPrefix(url.Path, "/")) + if restServiceConfig == nil { + logger.Errorf("%s service doesn't has provider config", url.Path) + return nil + } rp.SetExporterMap(serviceKey, exporter) - restServer := rp.getServer(url, restConfig) - restServer.Deploy(invoker, restConfig.RestMethodConfigsMap) + restServer := rp.getServer(url, restServiceConfig.Server) + restServer.Deploy(invoker, restServiceConfig.RestMethodConfigsMap) return exporter } @@ -78,15 +83,19 @@ func (rp *RestProtocol) Refer(url common.URL) protocol.Invoker { if t, err := time.ParseDuration(requestTimeoutStr); err == nil { requestTimeout = t } - restConfig := GetRestConsumerServiceConfig(strings.TrimPrefix(url.Path, "/")) + restServiceConfig := GetRestConsumerServiceConfig(strings.TrimPrefix(url.Path, "/")) + if restServiceConfig == nil { + logger.Errorf("%s service doesn't has consumer config", url.Path) + return nil + } restOptions := rest_interface.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout} - restClient := rp.getClient(restOptions, restConfig) - invoker := NewRestInvoker(url, &restClient, restConfig.RestMethodConfigsMap) + restClient := rp.getClient(restOptions, restServiceConfig.Client) + invoker := NewRestInvoker(url, &restClient, restServiceConfig.RestMethodConfigsMap) rp.SetInvokers(invoker) return invoker } -func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.RestServiceConfig) rest_interface.RestServer { +func (rp *RestProtocol) getServer(url common.URL, serverType string) rest_interface.RestServer { restServer, ok := rp.serverMap[url.Location] if !ok { _, ok := rp.ExporterMap().Load(url.ServiceKey()) @@ -96,7 +105,7 @@ func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.Res rp.serverLock.Lock() restServer, ok = rp.serverMap[url.Location] if !ok { - restServer = extension.GetNewRestServer(restConfig.Server) + restServer = extension.GetNewRestServer(serverType) restServer.Start(url) rp.serverMap[url.Location] = restServer } @@ -106,13 +115,13 @@ func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.Res return restServer } -func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, restConfig *rest_interface.RestServiceConfig) rest_interface.RestClient { +func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, clientType string) rest_interface.RestClient { restClient, ok := rp.clientMap[restOptions] rp.clientLock.Lock() if !ok { restClient, ok = rp.clientMap[restOptions] if !ok { - restClient = extension.GetNewRestClient(restConfig.Client, &restOptions) + restClient = extension.GetNewRestClient(clientType, &restOptions) rp.clientMap[restOptions] = restClient } }