Skip to content

Commit

Permalink
modify rest protocol
Browse files Browse the repository at this point in the history
  • Loading branch information
Patrick0308 committed Mar 9, 2020
1 parent 3a46e04 commit fa27aa2
Showing 1 changed file with 19 additions and 10 deletions.
29 changes: 19 additions & 10 deletions protocol/rest/rest_protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/apache/dubbo-go/common"
"github.com/apache/dubbo-go/common/constant"
"github.com/apache/dubbo-go/common/extension"
"github.com/apache/dubbo-go/common/logger"
"github.com/apache/dubbo-go/config"
"github.com/apache/dubbo-go/protocol"
"github.com/apache/dubbo-go/protocol/rest/rest_interface"
Expand Down Expand Up @@ -63,10 +64,14 @@ func (rp *RestProtocol) Export(invoker protocol.Invoker) protocol.Exporter {
url := invoker.GetUrl()
serviceKey := url.ServiceKey()
exporter := NewRestExporter(serviceKey, invoker, rp.ExporterMap())
restConfig := GetRestProviderServiceConfig(strings.TrimPrefix(url.Path, "/"))
restServiceConfig := GetRestProviderServiceConfig(strings.TrimPrefix(url.Path, "/"))
if restServiceConfig == nil {
logger.Errorf("%s service doesn't has provider config", url.Path)
return nil
}
rp.SetExporterMap(serviceKey, exporter)
restServer := rp.getServer(url, restConfig)
restServer.Deploy(invoker, restConfig.RestMethodConfigsMap)
restServer := rp.getServer(url, restServiceConfig.Server)
restServer.Deploy(invoker, restServiceConfig.RestMethodConfigsMap)
return exporter
}

Expand All @@ -78,15 +83,19 @@ func (rp *RestProtocol) Refer(url common.URL) protocol.Invoker {
if t, err := time.ParseDuration(requestTimeoutStr); err == nil {
requestTimeout = t
}
restConfig := GetRestConsumerServiceConfig(strings.TrimPrefix(url.Path, "/"))
restServiceConfig := GetRestConsumerServiceConfig(strings.TrimPrefix(url.Path, "/"))
if restServiceConfig == nil {
logger.Errorf("%s service doesn't has consumer config", url.Path)
return nil
}
restOptions := rest_interface.RestOptions{RequestTimeout: requestTimeout, ConnectTimeout: connectTimeout}
restClient := rp.getClient(restOptions, restConfig)
invoker := NewRestInvoker(url, &restClient, restConfig.RestMethodConfigsMap)
restClient := rp.getClient(restOptions, restServiceConfig.Client)
invoker := NewRestInvoker(url, &restClient, restServiceConfig.RestMethodConfigsMap)
rp.SetInvokers(invoker)
return invoker
}

func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.RestServiceConfig) rest_interface.RestServer {
func (rp *RestProtocol) getServer(url common.URL, serverType string) rest_interface.RestServer {
restServer, ok := rp.serverMap[url.Location]
if !ok {
_, ok := rp.ExporterMap().Load(url.ServiceKey())
Expand All @@ -96,7 +105,7 @@ func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.Res
rp.serverLock.Lock()
restServer, ok = rp.serverMap[url.Location]
if !ok {
restServer = extension.GetNewRestServer(restConfig.Server)
restServer = extension.GetNewRestServer(serverType)
restServer.Start(url)
rp.serverMap[url.Location] = restServer
}
Expand All @@ -106,13 +115,13 @@ func (rp *RestProtocol) getServer(url common.URL, restConfig *rest_interface.Res
return restServer
}

func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, restConfig *rest_interface.RestServiceConfig) rest_interface.RestClient {
func (rp *RestProtocol) getClient(restOptions rest_interface.RestOptions, clientType string) rest_interface.RestClient {
restClient, ok := rp.clientMap[restOptions]
rp.clientLock.Lock()
if !ok {
restClient, ok = rp.clientMap[restOptions]
if !ok {
restClient = extension.GetNewRestClient(restConfig.Client, &restOptions)
restClient = extension.GetNewRestClient(clientType, &restOptions)
rp.clientMap[restOptions] = restClient
}
}
Expand Down

0 comments on commit fa27aa2

Please sign in to comment.