diff --git a/src/Database/Redis/Cluster.hs b/src/Database/Redis/Cluster.hs index 2f577bc2..d1822c8f 100644 --- a/src/Database/Redis/Cluster.hs +++ b/src/Database/Redis/Cluster.hs @@ -36,6 +36,7 @@ import System.IO.Unsafe(unsafeInterleaveIO) import Database.Redis.Protocol(Reply(Error), renderRequest, reply) import qualified Database.Redis.Cluster.Command as CMD +import Network.TLS (ClientParams (..)) -- This module implements a clustered connection whilst maintaining -- compatibility with the original Hedis codebase. In particular it still @@ -100,8 +101,8 @@ instance Exception UnsupportedClusterCommandException newtype CrossSlotException = CrossSlotException [[B.ByteString]] deriving (Show, Typeable) instance Exception CrossSlotException -connect :: [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection -connect commandInfos shardMapVar timeoutOpt = do +connect :: Maybe ClientParams -> [CMD.CommandInfo] -> MVar ShardMap -> Maybe Int -> IO Connection +connect mTlsParams commandInfos shardMapVar timeoutOpt = do shardMap <- readMVar shardMapVar stateVar <- newMVar $ Pending [] pipelineVar <- newMVar $ Pipeline stateVar @@ -111,7 +112,18 @@ connect commandInfos shardMapVar timeoutOpt = do nodeConnections shardMap = HM.fromList <$> mapM connectNode (nub $ nodes shardMap) connectNode :: Node -> IO (NodeID, NodeConnection) connectNode (Node n _ host port) = do - ctx <- CC.connect host (CC.PortNumber $ toEnum port) timeoutOpt + ctx0 <- CC.connect host (CC.PortNumber $ toEnum port) timeoutOpt + ctx <- case mTlsParams of + Nothing -> pure ctx0 + Just defaultTlsParams -> do + -- The defaultTlsParams are used to connect to the first + -- host in the cluster, other hosts have different + -- hostnames and so require a different server + -- identification params + let tlsParams = defaultTlsParams { + clientServerIdentification = (host, Char8.pack $ show port) + } + CC.enableTLS tlsParams ctx0 ref <- IOR.newIORef Nothing return (n, NodeConnection ctx ref n) diff --git a/src/Database/Redis/Connection.hs b/src/Database/Redis/Connection.hs index 156662ec..d3ac0c9e 100644 --- a/src/Database/Redis/Connection.hs +++ b/src/Database/Redis/Connection.hs @@ -231,9 +231,9 @@ connectCluster bootstrapConnInfo = do Left e -> throwIO $ ClusterConnectError e Right infos -> do #if MIN_VERSION_resource_pool(0,3,0) - pool <- newPool (defaultPoolConfig (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)) + pool <- newPool (defaultPoolConfig (Cluster.connect (connectTLSParams bootstrapConnInfo) infos shardMapVar Nothing) Cluster.disconnect (realToFrac $ connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo)) #else - pool <- createPool (Cluster.connect infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) + pool <- createPool (Cluster.connect (connectTLSParams bootstrapConnInfo) infos shardMapVar Nothing) Cluster.disconnect 1 (connectMaxIdleTime bootstrapConnInfo) (connectMaxConnections bootstrapConnInfo) #endif return $ ClusteredConnection shardMapVar pool