From a97a7686de9deb13b78969bddd5d6b39ca89fb77 Mon Sep 17 00:00:00 2001 From: jake Date: Thu, 15 Feb 2024 12:59:24 -0600 Subject: [PATCH] Add `fourmolu` formatting rules (#158) * add: `fourmolu` formatting rules * fix: formatting issue with C guards --- bench/Bench.hs | 224 +++--- core/src/Network/GRPC/LowLevel.hs | 215 +++--- core/src/Network/GRPC/LowLevel/Call.hs | 135 ++-- .../GRPC/LowLevel/Call/Unregistered.hs | 28 +- core/src/Network/GRPC/LowLevel/Client.hs | 487 ++++++------ .../GRPC/LowLevel/Client/Unregistered.hs | 108 +-- .../Network/GRPC/LowLevel/CompletionQueue.hs | 217 +++--- .../GRPC/LowLevel/CompletionQueue/Internal.hs | 138 ++-- .../LowLevel/CompletionQueue/Unregistered.hs | 94 ++- core/src/Network/GRPC/LowLevel/GRPC.hs | 96 +-- .../Network/GRPC/LowLevel/GRPC/MetadataMap.hs | 66 +- core/src/Network/GRPC/LowLevel/Op.hs | 280 +++---- core/src/Network/GRPC/LowLevel/Server.hs | 575 +++++++------- .../GRPC/LowLevel/Server/Unregistered.hs | 227 +++--- core/tests/LowLevelTests.hs | 716 ++++++++++-------- core/tests/LowLevelTests/Op.hs | 28 +- core/tests/Properties.hs | 23 +- core/tests/UnsafeTests.hs | 132 ++-- examples/echo/echo-hs/Echo.hs | 445 ++++++----- examples/echo/echo-hs/EchoClient.hs | 60 +- examples/echo/echo-hs/EchoServer.hs | 61 +- examples/hellos/hellos-client/Main.hs | 102 +-- examples/hellos/hellos-server/Main.hs | 65 +- examples/tutorial/Arithmetic.hs | 447 ++++++----- examples/tutorial/ArithmeticClient.hs | 52 +- examples/tutorial/ArithmeticServer.hs | 68 +- fourmolu.yaml | 51 ++ src/Network/GRPC/HighLevel.hs | 101 ++- src/Network/GRPC/HighLevel/Client.hs | 243 +++--- src/Network/GRPC/HighLevel/Generated.hs | 130 ++-- src/Network/GRPC/HighLevel/Server.hs | 308 ++++---- .../GRPC/HighLevel/Server/Unregistered.hs | 127 ++-- tests/GeneratedTests.hs | 76 +- tests/Properties.hs | 6 +- tests/TestClient.hs | 214 +++--- tests/TestServer.hs | 85 ++- 36 files changed, 3577 insertions(+), 2853 deletions(-) create mode 100644 fourmolu.yaml diff --git a/bench/Bench.hs b/bench/Bench.hs index 43e27152..53e8adf8 100644 --- a/bench/Bench.hs +++ b/bench/Bench.hs @@ -1,28 +1,30 @@ -{-# LANGUAGE BangPatterns #-} -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -import Control.Concurrent -import Control.Concurrent.Async -import Control.Exception (bracket) -import Control.Monad -import Criterion.Main -import Criterion.Types (Config (..)) -import qualified Data.ByteString.Lazy as BL -import Data.Word -import GHC.Generics (Generic) -import Network.GRPC.HighLevel.Server hiding (serverLoop) -import Network.GRPC.HighLevel.Server.Unregistered (serverLoop) -import Network.GRPC.LowLevel -import Network.GRPC.LowLevel.GRPC (threadDelaySecs) -import Proto3.Suite.Class -import Proto3.Suite.Types -import System.Random (randomRIO) - -data AddRequest = AddRequest {addX :: Fixed Word32 - , addY :: Fixed Word32} +import Control.Concurrent +import Control.Concurrent.Async +import Control.Exception (bracket) +import Control.Monad +import Criterion.Main +import Criterion.Types (Config (..)) +import qualified Data.ByteString.Lazy as BL +import Data.Word +import GHC.Generics (Generic) +import Network.GRPC.HighLevel.Server hiding (serverLoop) +import Network.GRPC.HighLevel.Server.Unregistered (serverLoop) +import Network.GRPC.LowLevel +import Network.GRPC.LowLevel.GRPC (threadDelaySecs) +import Proto3.Suite.Class +import Proto3.Suite.Types +import System.Random (randomRIO) + +data AddRequest = AddRequest + { addX :: Fixed Word32 + , addY :: Fixed Word32 + } deriving (Show, Eq, Ord, Generic) instance Message AddRequest @@ -47,53 +49,57 @@ addHandler = UnaryHandler addMethod $ \c -> do let b = payload c - return ( AddResponse $ addX b + addY b - , metadata c - , StatusOk - , StatusDetails "" - ) + return + ( AddResponse $ addX b + addY b + , metadata c + , StatusOk + , StatusDetails "" + ) addClientStreamHandler :: Handler 'ClientStreaming addClientStreamHandler = ClientStreamHandler addClientStreamMethod $ - \_ recv -> do - answer <- go recv 0 - return (Just answer, mempty, StatusOk, "") - where go recv !i = do - req <- recv - case req of - Left _ -> return $ AddResponse i - Right Nothing -> return $ AddResponse i - Right (Just (AddRequest x y)) -> go recv (i+x+y) + \_ recv -> do + answer <- go recv 0 + return (Just answer, mempty, StatusOk, "") + where + go recv !i = do + req <- recv + case req of + Left _ -> return $ AddResponse i + Right Nothing -> return $ AddResponse i + Right (Just (AddRequest x y)) -> go recv (i + x + y) addServerStreamHandler :: Handler 'ServerStreaming addServerStreamHandler = ServerStreamHandler addServerStreamMethod $ - \c send -> do - let AddRequest (Fixed x) y = payload c - replicateM_ (fromIntegral x) $ send $ AddResponse y - return (mempty, StatusOk, "") + \c send -> do + let AddRequest (Fixed x) y = payload c + replicateM_ (fromIntegral x) $ send $ AddResponse y + return (mempty, StatusOk, "") addBiDiHandler :: Handler 'BiDiStreaming addBiDiHandler = BiDiStreamHandler addBiDiMethod (go 0) - where go :: Fixed Word32 -> ServerRWHandler AddRequest AddResponse - go !i c recv send = do - req <- recv - case req of - Left _ -> return (mempty, StatusOk, "") - Right Nothing -> return (mempty, StatusOk, "") - Right (Just (AddRequest x y)) -> do - let curr = i + x + y - void $ send $ AddResponse curr - go curr c recv send - + where + go :: Fixed Word32 -> ServerRWHandler AddRequest AddResponse + go !i c recv send = do + req <- recv + case req of + Left _ -> return (mempty, StatusOk, "") + Right Nothing -> return (mempty, StatusOk, "") + Right (Just (AddRequest x y)) -> do + let curr = i + x + y + void $ send $ AddResponse curr + go curr c recv send serverOpts :: ServerOptions serverOpts = - defaultOptions{optNormalHandlers = [addHandler] - , optClientStreamHandlers = [addClientStreamHandler] - , optServerStreamHandlers = [addServerStreamHandler] - , optBiDiStreamHandlers = [addBiDiHandler]} + defaultOptions + { optNormalHandlers = [addHandler] + , optClientStreamHandlers = [addClientStreamHandler] + , optServerStreamHandlers = [addServerStreamHandler] + , optBiDiStreamHandlers = [addBiDiHandler] + } main :: IO () main = bracket startServer stopServer $ const $ withGRPC $ \grpc -> @@ -113,54 +119,58 @@ main = bracket startServer stopServer $ const $ withGRPC $ \grpc -> , bench "server stream: 10k messages" $ nfIO (addServerStream c rmServerStream 10000) , bench "bidi stream: 50 messages up, 50 down" $ nfIO (bidiStream c rmBiDiStream 50) , bench "bidi stream: 500 message up, 500 down" $ nfIO (bidiStream c rmBiDiStream 500) - , bench "bidi stream: 5000 messages up, 5000 down" $ nfIO (bidiStream c rmBiDiStream 5000)] - - where startServer = do - sThrd <- async $ serverLoop serverOpts - threadDelaySecs 1 - return sThrd - - stopServer sThrd = cancel sThrd >> void (waitCatch sThrd) - - encode = BL.toStrict . toLazyByteString - - addRequest c rmAdd = do - x <- liftM Fixed $ randomRIO (0,1000) - y <- liftM Fixed $ randomRIO (0,1000) - let addEnc = BL.toStrict . toLazyByteString $ AddRequest x y - clientRequest c rmAdd 5 addEnc mempty >>= \case - Left e -> fail $ "Got client error on add request: " ++ show e - Right r -> case fromByteString (rspBody r) of - Left e -> fail $ "failed to decode add response: " ++ show e - Right dec - | dec == AddResponse (x + y) -> return () - | otherwise -> fail $ "Got wrong add answer: " ++ show dec ++ "; expected: " ++ show x ++ " + " ++ show y ++ " = " ++ show (x+y) - - addClientStream c rm i = do - let msg = encode $ AddRequest 1 0 - Right (Just r,_,_,_,_) <- clientWriter c rm 5 mempty $ \send -> do - replicateM_ i $ send msg - let decoded = fromByteString r - when (decoded /= Right (AddResponse (fromIntegral i))) $ - fail $ "clientStream: bad answer: " ++ show decoded ++ "; expected: " ++ show i - - addServerStream c rm i = do - let msg = encode $ AddRequest (fromIntegral i) 2 - Right (_, _, sd) <- clientReader c rm 5 msg mempty $ \_ recv -> - replicateM_ i $ do - Right (Just bs) <- recv - let Right decoded = fromByteString bs - when (decoded /= AddResponse 2) $ - fail $ "serverStream: bad response of " ++ show decoded ++ "; expected 2." - when (sd /= mempty) $ fail $ "bad status details: " ++ show sd - - bidiStream c rm i = do - Right (_, _, sd) <- clientRW c rm 5 mempty $ \_ recv send done -> do - forM_ (take i [2,4..]) $ \n -> do - void $ send $ encode $ AddRequest 1 1 - Right (Just bs) <- recv - let Right decoded = fromByteString bs - when (decoded /= AddResponse n) $ - fail $ "bidiStream: got: " ++ show decoded ++ "expected: " ++ show n - void done - when (sd /= mempty) $ fail $ "bad StatusDetails: " ++ show sd + , bench "bidi stream: 5000 messages up, 5000 down" $ nfIO (bidiStream c rmBiDiStream 5000) + ] + where + startServer = do + sThrd <- async $ serverLoop serverOpts + threadDelaySecs 1 + return sThrd + + stopServer sThrd = cancel sThrd >> void (waitCatch sThrd) + + encode = BL.toStrict . toLazyByteString + + addRequest c rmAdd = do + x <- liftM Fixed $ randomRIO (0, 1000) + y <- liftM Fixed $ randomRIO (0, 1000) + let addEnc = BL.toStrict . toLazyByteString $ AddRequest x y + clientRequest c rmAdd 5 addEnc mempty >>= \case + Left e -> fail $ "Got client error on add request: " ++ show e + Right r -> case fromByteString (rspBody r) of + Left e -> fail $ "failed to decode add response: " ++ show e + Right dec + | dec == AddResponse (x + y) -> return () + | otherwise -> fail $ "Got wrong add answer: " ++ show dec ++ "; expected: " ++ show x ++ " + " ++ show y ++ " = " ++ show (x + y) + + addClientStream c rm i = do + let msg = encode $ AddRequest 1 0 + Right (Just r, _, _, _, _) <- clientWriter c rm 5 mempty $ \send -> do + replicateM_ i $ send msg + let decoded = fromByteString r + when (decoded /= Right (AddResponse (fromIntegral i))) $ + fail $ + "clientStream: bad answer: " ++ show decoded ++ "; expected: " ++ show i + + addServerStream c rm i = do + let msg = encode $ AddRequest (fromIntegral i) 2 + Right (_, _, sd) <- clientReader c rm 5 msg mempty $ \_ recv -> + replicateM_ i $ do + Right (Just bs) <- recv + let Right decoded = fromByteString bs + when (decoded /= AddResponse 2) $ + fail $ + "serverStream: bad response of " ++ show decoded ++ "; expected 2." + when (sd /= mempty) $ fail $ "bad status details: " ++ show sd + + bidiStream c rm i = do + Right (_, _, sd) <- clientRW c rm 5 mempty $ \_ recv send done -> do + forM_ (take i [2, 4 ..]) $ \n -> do + void $ send $ encode $ AddRequest 1 1 + Right (Just bs) <- recv + let Right decoded = fromByteString bs + when (decoded /= AddResponse n) $ + fail $ + "bidiStream: got: " ++ show decoded ++ "expected: " ++ show n + void done + when (sd /= mempty) $ fail $ "bad StatusDetails: " ++ show sd diff --git a/core/src/Network/GRPC/LowLevel.hs b/core/src/Network/GRPC/LowLevel.hs index dad075d6..acbcde23 100644 --- a/core/src/Network/GRPC/LowLevel.hs +++ b/core/src/Network/GRPC/LowLevel.hs @@ -1,124 +1,131 @@ +{-# LANGUAGE RecordWildCards #-} + -- | Low-level safe interface to gRPC. By "safe", we mean: -- 1. all gRPC objects are guaranteed to be cleaned up correctly. -- 2. all functions are thread-safe. -- 3. all functions leave gRPC in a consistent, safe state. -- These guarantees only apply to the functions exported by this module, -- and not to helper functions in submodules that aren't exported here. - -{-# LANGUAGE RecordWildCards #-} - module Network.GRPC.LowLevel ( --- * Important types -GRPC -, withGRPC -, GRPCIOError(..) -, StatusCode(..) - --- * Completion queue utilities -, CompletionQueue -, withCompletionQueue + -- * Important types + GRPC, + withGRPC, + GRPCIOError (..), + StatusCode (..), --- * Calls -, GRPCMethodType(..) -, RegisteredMethod -, MethodPayload -, NormalRequestResult(..) -, MetadataMap(..) -, MethodName(..) -, StatusDetails(..) + -- * Completion queue utilities + CompletionQueue, + withCompletionQueue, --- * Configuration options -, Arg(..) -, CompressionAlgorithm(..) -, CompressionLevel(..) -, Host(..) -, Port(..) + -- * Calls + GRPCMethodType (..), + RegisteredMethod, + MethodPayload, + NormalRequestResult (..), + MetadataMap (..), + MethodName (..), + StatusDetails (..), --- * Server -, ServerConfig(..) -, Server(normalMethods, sstreamingMethods, cstreamingMethods, - bidiStreamingMethods) -, ServerCall(payload, metadata) -, withServer -, serverHandleNormalCall -, ServerHandlerLL -, withServerCall -, serverCallCancel -, serverCallIsExpired -, serverReader -- for client streaming -, ServerReaderHandlerLL -, serverWriter -- for server streaming -, ServerWriterHandlerLL -, serverRW -- for bidirectional streaming -, ServerRWHandlerLL + -- * Configuration options + Arg (..), + CompressionAlgorithm (..), + CompressionLevel (..), + Host (..), + Port (..), --- * Client and Server Auth -, AuthContext -, AuthProperty(..) -, getAuthProperties -, addAuthProperty + -- * Server + ServerConfig (..), + Server ( + normalMethods, + sstreamingMethods, + cstreamingMethods, + bidiStreamingMethods + ), + ServerCall (payload, metadata), + withServer, + serverHandleNormalCall, + ServerHandlerLL, + withServerCall, + serverCallCancel, + serverCallIsExpired, + serverReader, -- for client streaming + ServerReaderHandlerLL, + serverWriter, -- for server streaming + ServerWriterHandlerLL, + serverRW, -- for bidirectional streaming + ServerRWHandlerLL, --- * Server Auth -, ServerSSLConfig(..) -, ProcessMeta -, AuthProcessorResult(..) -, SslClientCertificateRequestType(..) + -- * Client and Server Auth + AuthContext, + AuthProperty (..), + getAuthProperties, + addAuthProperty, --- * Client Auth -, ClientSSLConfig(..) -, ClientSSLKeyCertPair(..) -, ClientMetadataCreate -, ClientMetadataCreateResult(..) -, AuthMetadataContext(..) + -- * Server Auth + ServerSSLConfig (..), + ProcessMeta, + AuthProcessorResult (..), + SslClientCertificateRequestType (..), --- * Client -, ClientConfig(..) -, Client -, ClientCall -, ConnectivityState(..) -, clientConnectivity -, withClient -, clientRegisterMethodNormal -, clientRegisterMethodClientStreaming -, clientRegisterMethodServerStreaming -, clientRegisterMethodBiDiStreaming -, clientRequest -, clientRequestParent -, clientReader -- for server streaming -, clientWriter -- for client streaming -, clientRW -- for bidirectional streaming -, withClientCall -, withClientCallParent -, clientCallCancel + -- * Client Auth + ClientSSLConfig (..), + ClientSSLKeyCertPair (..), + ClientMetadataCreate, + ClientMetadataCreateResult (..), + AuthMetadataContext (..), --- * Ops -, Op(..) -, OpRecvResult(..) + -- * Client + ClientConfig (..), + Client, + ClientCall, + ConnectivityState (..), + clientConnectivity, + withClient, + clientRegisterMethodNormal, + clientRegisterMethodClientStreaming, + clientRegisterMethodServerStreaming, + clientRegisterMethodBiDiStreaming, + clientRequest, + clientRequestParent, + clientReader, -- for server streaming + clientWriter, -- for client streaming + clientRW, -- for bidirectional streaming + withClientCall, + withClientCallParent, + clientCallCancel, --- * Streaming utilities -, StreamSend -, StreamRecv + -- * Ops + Op (..), + OpRecvResult (..), + -- * Streaming utilities + StreamSend, + StreamRecv, ) where -import Network.GRPC.LowLevel.Call -import Network.GRPC.LowLevel.Client -import Network.GRPC.LowLevel.CompletionQueue -import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op -import Network.GRPC.LowLevel.Server +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.Client +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Op +import Network.GRPC.LowLevel.Server -import Network.GRPC.Unsafe (ConnectivityState (..)) -import Network.GRPC.Unsafe.ChannelArgs (Arg (..), CompressionAlgorithm (..), - CompressionLevel (..)) -import Network.GRPC.Unsafe.Op (StatusCode (..)) -import Network.GRPC.Unsafe.Security (AuthContext, - AuthMetadataContext (..), - AuthProcessorResult (..), - AuthProperty (..), - ClientMetadataCreate, - ClientMetadataCreateResult (..), - ProcessMeta, - SslClientCertificateRequestType (..), - addAuthProperty, - getAuthProperties) +import Network.GRPC.Unsafe (ConnectivityState (..)) +import Network.GRPC.Unsafe.ChannelArgs ( + Arg (..), + CompressionAlgorithm (..), + CompressionLevel (..), + ) +import Network.GRPC.Unsafe.Op (StatusCode (..)) +import Network.GRPC.Unsafe.Security ( + AuthContext, + AuthMetadataContext (..), + AuthProcessorResult (..), + AuthProperty (..), + ClientMetadataCreate, + ClientMetadataCreateResult (..), + ProcessMeta, + SslClientCertificateRequestType (..), + addAuthProperty, + getAuthProperties, + ) diff --git a/core/src/Network/GRPC/LowLevel/Call.hs b/core/src/Network/GRPC/LowLevel/Call.hs index 9962ce03..e139150a 100644 --- a/core/src/Network/GRPC/LowLevel/Call.hs +++ b/core/src/Network/GRPC/LowLevel/Call.hs @@ -1,37 +1,36 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveFunctor #-} -{-# LANGUAGE FlexibleInstances #-} -{-# LANGUAGE GADTs #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveFunctor #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE NamedFieldPuns #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TypeFamilies #-} -- | This module defines data structures and operations pertaining to registered -- calls; for unregistered call support, see -- `Network.GRPC.LowLevel.Call.Unregistered`. module Network.GRPC.LowLevel.Call where -import Control.Monad.Managed (Managed, managed) -import Control.Exception (bracket) -import Data.ByteString (ByteString) -import Data.ByteString.Char8 (pack) -import Data.List (intersperse) -import Data.String (IsString) -import Foreign.Marshal.Alloc (free, malloc) -import Foreign.Ptr (Ptr, nullPtr) -import Foreign.Storable (Storable, peek) -import Network.GRPC.LowLevel.CompletionQueue.Internal -import Network.GRPC.LowLevel.GRPC (MetadataMap, - grpcDebug) -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.ByteBuffer as C -import qualified Network.GRPC.Unsafe.Op as C -import qualified Network.GRPC.Unsafe.Time as C -import System.Clock +import Control.Exception (bracket) +import Control.Monad.Managed (Managed, managed) +import Data.ByteString (ByteString) +import Data.ByteString.Char8 (pack) +import Data.List (intersperse) +import Data.String (IsString) +import Foreign.Marshal.Alloc (free, malloc) +import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Storable (Storable, peek) +import Network.GRPC.LowLevel.CompletionQueue.Internal +import Network.GRPC.LowLevel.GRPC (MetadataMap, grpcDebug) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.ByteBuffer as C +import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Time as C +import System.Clock -- | Models the four types of RPC call supported by gRPC (and correspond to -- DataKinds phantom types on RegisteredMethods). @@ -48,12 +47,13 @@ type family MethodPayload a where MethodPayload 'ServerStreaming = ByteString MethodPayload 'BiDiStreaming = () ---TODO: try replacing this class with a plain old function so we don't have the +-- TODO: try replacing this class with a plain old function so we don't have the -- Payloadable constraint everywhere. -extractPayload :: RegisteredMethod mt - -> Ptr C.ByteBuffer - -> IO (MethodPayload mt) +extractPayload :: + RegisteredMethod mt -> + Ptr C.ByteBuffer -> + IO (MethodPayload mt) extractPayload (RegisteredMethodNormal _ _ _) p = peek p >>= C.copyByteBufferToByteString extractPayload (RegisteredMethodClientStreaming _ _ _) _ = return () @@ -85,36 +85,40 @@ endpoint (Host h) (Port p) = Endpoint (h <> ":" <> pack (show p)) -- library. Note that we use a DataKind-ed phantom type to help constrain use of -- different kinds of registered methods. data RegisteredMethod (mt :: GRPCMethodType) where - RegisteredMethodNormal :: MethodName - -> Endpoint - -> C.CallHandle - -> RegisteredMethod 'Normal - RegisteredMethodClientStreaming :: MethodName - -> Endpoint - -> C.CallHandle - -> RegisteredMethod 'ClientStreaming - RegisteredMethodServerStreaming :: MethodName - -> Endpoint - -> C.CallHandle - -> RegisteredMethod 'ServerStreaming - RegisteredMethodBiDiStreaming :: MethodName - -> Endpoint - -> C.CallHandle - -> RegisteredMethod 'BiDiStreaming + RegisteredMethodNormal :: + MethodName -> + Endpoint -> + C.CallHandle -> + RegisteredMethod 'Normal + RegisteredMethodClientStreaming :: + MethodName -> + Endpoint -> + C.CallHandle -> + RegisteredMethod 'ClientStreaming + RegisteredMethodServerStreaming :: + MethodName -> + Endpoint -> + C.CallHandle -> + RegisteredMethod 'ServerStreaming + RegisteredMethodBiDiStreaming :: + MethodName -> + Endpoint -> + C.CallHandle -> + RegisteredMethod 'BiDiStreaming instance Show (RegisteredMethod a) where show (RegisteredMethodNormal x y z) = "RegisteredMethodNormal " - ++ concat (intersperse " " [show x, show y, show z]) + ++ concat (intersperse " " [show x, show y, show z]) show (RegisteredMethodClientStreaming x y z) = "RegisteredMethodClientStreaming " - ++ concat (intersperse " " [show x, show y, show z]) + ++ concat (intersperse " " [show x, show y, show z]) show (RegisteredMethodServerStreaming x y z) = "RegisteredMethodServerStreaming " - ++ concat (intersperse " " [show x, show y, show z]) + ++ concat (intersperse " " [show x, show y, show z]) show (RegisteredMethodBiDiStreaming x y z) = "RegisteredMethodBiDiStreaming " - ++ concat (intersperse " " [show x, show y, show z]) + ++ concat (intersperse " " [show x, show y, show z]) methodName :: RegisteredMethod mt -> MethodName methodName (RegisteredMethodNormal x _ _) = x @@ -135,14 +139,14 @@ methodHandle (RegisteredMethodServerStreaming _ _ x) = x methodHandle (RegisteredMethodBiDiStreaming _ _ x) = x methodType :: RegisteredMethod mt -> GRPCMethodType -methodType (RegisteredMethodNormal _ _ _) = Normal +methodType (RegisteredMethodNormal _ _ _) = Normal methodType (RegisteredMethodClientStreaming _ _ _) = ClientStreaming methodType (RegisteredMethodServerStreaming _ _ _) = ServerStreaming -methodType (RegisteredMethodBiDiStreaming _ _ _) = BiDiStreaming +methodType (RegisteredMethodBiDiStreaming _ _ _) = BiDiStreaming -- | Represents one GRPC call (i.e. request) on the client. -- This is used to associate send/receive 'Op's with a request. -data ClientCall = ClientCall { unsafeCC :: C.Call } +data ClientCall = ClientCall {unsafeCC :: C.Call} clientCallCancel :: ClientCall -> IO () clientCallCancel cc = C.grpcCallCancel (unsafeCC cc) C.reserved @@ -150,12 +154,13 @@ clientCallCancel cc = C.grpcCallCancel (unsafeCC cc) C.reserved -- | Represents one registered GRPC call on the server. Contains pointers to all -- the C state needed to respond to a registered call. data ServerCall a = ServerCall - { unsafeSC :: C.Call - , callCQ :: CompletionQueue - , metadata :: MetadataMap - , payload :: a - , callDeadline :: TimeSpec - } deriving (Functor, Show) + { unsafeSC :: C.Call + , callCQ :: CompletionQueue + , metadata :: MetadataMap + , payload :: a + , callDeadline :: TimeSpec + } + deriving (Functor, Show) serverCallCancel :: ServerCall a -> C.StatusCode -> String -> IO () serverCallCancel sc code reason = @@ -164,19 +169,19 @@ serverCallCancel sc code reason = -- | NB: For now, we've assumed that the method type is all the info we need to -- decide the server payload handling method. payloadHandling :: GRPCMethodType -> C.ServerRegisterMethodPayloadHandling -payloadHandling Normal = C.SrmPayloadReadInitialByteBuffer +payloadHandling Normal = C.SrmPayloadReadInitialByteBuffer payloadHandling ClientStreaming = C.SrmPayloadNone payloadHandling ServerStreaming = C.SrmPayloadReadInitialByteBuffer -payloadHandling BiDiStreaming = C.SrmPayloadNone +payloadHandling BiDiStreaming = C.SrmPayloadNone -- | Optionally allocate a managed byte buffer for a payload, depending on the -- given method type. If no payload is needed, the returned pointer is null mgdPayload :: GRPCMethodType -> Managed (Ptr C.ByteBuffer) mgdPayload mt | payloadHandling mt == C.SrmPayloadNone = return nullPtr - | otherwise = managed C.withByteBufferPtr + | otherwise = managed C.withByteBufferPtr -mgdPtr :: forall a. Storable a => Managed (Ptr a) +mgdPtr :: forall a. (Storable a) => Managed (Ptr a) mgdPtr = managed (bracket malloc free) serverCallIsExpired :: ServerCall a -> IO Bool @@ -212,7 +217,7 @@ destroyClientCall cc = do C.grpcCallUnref (unsafeCC cc) destroyServerCall :: ServerCall a -> IO () -destroyServerCall sc@ServerCall{ unsafeSC } = do +destroyServerCall sc@ServerCall{unsafeSC} = do grpcDebug "destroyServerCall(R): entered." debugServerCall sc grpcDebug $ "Destroying server-side call object: " ++ show unsafeSC diff --git a/core/src/Network/GRPC/LowLevel/Call/Unregistered.hs b/core/src/Network/GRPC/LowLevel/Call/Unregistered.hs index 675ee2a0..29e1ba49 100644 --- a/core/src/Network/GRPC/LowLevel/Call/Unregistered.hs +++ b/core/src/Network/GRPC/LowLevel/Call/Unregistered.hs @@ -2,23 +2,25 @@ module Network.GRPC.LowLevel.Call.Unregistered where -import qualified Network.GRPC.LowLevel.Call as Reg -import Network.GRPC.LowLevel.CompletionQueue -import Network.GRPC.LowLevel.GRPC (MetadataMap, - grpcDebug) -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Op as C -import System.Clock (TimeSpec) +import qualified Network.GRPC.LowLevel.Call as Reg +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.GRPC ( + MetadataMap, + grpcDebug, + ) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Op as C +import System.Clock (TimeSpec) -- | Represents one unregistered GRPC call on the server. Contains pointers to -- all the C state needed to respond to an unregistered call. data ServerCall = ServerCall - { unsafeSC :: C.Call - , callCQ :: CompletionQueue - , metadata :: MetadataMap - , callDeadline :: TimeSpec - , callMethod :: Reg.MethodName - , callHost :: Reg.Host + { unsafeSC :: C.Call + , callCQ :: CompletionQueue + , metadata :: MetadataMap + , callDeadline :: TimeSpec + , callMethod :: Reg.MethodName + , callHost :: Reg.Host } convertCall :: ServerCall -> Reg.ServerCall () diff --git a/core/src/Network/GRPC/LowLevel/Client.hs b/core/src/Network/GRPC/LowLevel/Client.hs index 5f43764f..b617d914 100644 --- a/core/src/Network/GRPC/LowLevel/Client.hs +++ b/core/src/Network/GRPC/LowLevel/Client.hs @@ -1,81 +1,84 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} -- | This module defines data structures and operations pertaining to registered -- clients using registered calls; for unregistered support, see -- `Network.GRPC.LowLevel.Client.Unregistered`. module Network.GRPC.LowLevel.Client where -import Control.Exception (bracket) -import Control.Concurrent.MVar -import Control.Monad -import Control.Monad.IO.Class -import Control.Monad.Trans.Except -import qualified Data.ByteString as B -import Data.ByteString (ByteString) -import Data.Maybe -import Network.GRPC.LowLevel.Call -import Network.GRPC.LowLevel.CompletionQueue -import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.ChannelArgs as C -import qualified Network.GRPC.Unsafe.Constants as C -import qualified Network.GRPC.Unsafe.Op as C -import qualified Network.GRPC.Unsafe.Security as C -import qualified Network.GRPC.Unsafe.Time as C +import Control.Concurrent.MVar +import Control.Exception (bracket) +import Control.Monad +import Control.Monad.IO.Class +import Control.Monad.Trans.Except +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import Data.Maybe +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Op +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.ChannelArgs as C +import qualified Network.GRPC.Unsafe.Constants as C +import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Security as C +import qualified Network.GRPC.Unsafe.Time as C -- | Represents the context needed to perform client-side gRPC operations. -data Client = Client {clientChannel :: C.Channel, - clientCQ :: CompletionQueue, - clientConfig :: ClientConfig - } +data Client = Client + { clientChannel :: C.Channel + , clientCQ :: CompletionQueue + , clientConfig :: ClientConfig + } data ClientSSLKeyCertPair = ClientSSLKeyCertPair - { clientPrivateKey :: FilePath, - clientCert :: FilePath - } deriving Show + { clientPrivateKey :: FilePath + , clientCert :: FilePath + } + deriving (Show) -- | SSL configuration for the client. It's perfectly acceptable for both fields -- to be 'Nothing', in which case default fallbacks will be used for the server -- root cert. data ClientSSLConfig = ClientSSLConfig - {serverRootCert :: Maybe FilePath, - -- ^ Path to the server root certificate. If 'Nothing', gRPC will attempt to - -- fall back to a default. - clientSSLKeyCertPair :: Maybe ClientSSLKeyCertPair, - -- ^ The client's private key and cert, if available. - clientMetadataPlugin :: Maybe C.ClientMetadataCreate - -- ^ Optional plugin for attaching additional metadata to each call. + { serverRootCert :: Maybe FilePath + -- ^ Path to the server root certificate. If 'Nothing', gRPC will attempt to + -- fall back to a default. + , clientSSLKeyCertPair :: Maybe ClientSSLKeyCertPair + -- ^ The client's private key and cert, if available. + , clientMetadataPlugin :: Maybe C.ClientMetadataCreate + -- ^ Optional plugin for attaching additional metadata to each call. } -- | Configuration necessary to set up a client. +data ClientConfig = ClientConfig + { clientServerEndpoint :: Endpoint + , clientArgs :: [C.Arg] + -- ^ Optional arguments for setting up the + -- channel on the client. Supplying an empty + -- list will cause the channel to use gRPC's + -- default options. + , clientSSLConfig :: Maybe ClientSSLConfig + -- ^ If 'Nothing', the client will use an + -- insecure connection to the server. + -- Otherwise, will use the supplied config to + -- connect using SSL. + , clientAuthority :: Maybe ByteString + -- ^ If 'Nothing', the :authority pseudo-header will + -- be the endpoint host. Otherwise, the :authority + -- pseudo-header will be set to the supplied value. + } -data ClientConfig = ClientConfig {clientServerEndpoint :: Endpoint, - clientArgs :: [C.Arg], - -- ^ Optional arguments for setting up the - -- channel on the client. Supplying an empty - -- list will cause the channel to use gRPC's - -- default options. - clientSSLConfig :: Maybe ClientSSLConfig, - -- ^ If 'Nothing', the client will use an - -- insecure connection to the server. - -- Otherwise, will use the supplied config to - -- connect using SSL. - clientAuthority :: Maybe ByteString - -- ^ If 'Nothing', the :authority pseudo-header will - -- be the endpoint host. Otherwise, the :authority - -- pseudo-header will be set to the supplied value. - } - -addMetadataCreds :: C.ChannelCredentials - -> Maybe C.ClientMetadataCreate - -> IO C.ChannelCredentials +addMetadataCreds :: + C.ChannelCredentials -> + Maybe C.ClientMetadataCreate -> + IO C.ChannelCredentials addMetadataCreds c Nothing = return c addMetadataCreds c (Just create) = do callCreds <- C.createCustomCallCredentials create @@ -88,18 +91,21 @@ createChannel ClientConfig{..} chanargs = C.withInsecureChannelCredentials $ \creds -> C.grpcChannelCreate e creds chanargs Just (ClientSSLConfig rootCertPath Nothing plugin) -> - do rootCert <- mapM B.readFile rootCertPath - C.withChannelCredentials rootCert Nothing Nothing $ \creds -> do - creds' <- addMetadataCreds creds plugin - C.grpcChannelCreate e creds' chanargs + do + rootCert <- mapM B.readFile rootCertPath + C.withChannelCredentials rootCert Nothing Nothing $ \creds -> do + creds' <- addMetadataCreds creds plugin + C.grpcChannelCreate e creds' chanargs Just (ClientSSLConfig x (Just (ClientSSLKeyCertPair y z)) plugin) -> - do rootCert <- mapM B.readFile x - privKey <- Just <$> B.readFile y - clientCert <- Just <$> B.readFile z - C.withChannelCredentials rootCert privKey clientCert $ \creds -> do - creds' <- addMetadataCreds creds plugin - C.grpcChannelCreate e creds' chanargs - where (Endpoint e) = clientServerEndpoint + do + rootCert <- mapM B.readFile x + privKey <- Just <$> B.readFile y + clientCert <- Just <$> B.readFile z + C.withChannelCredentials rootCert privKey clientCert $ \creds -> do + creds' <- addMetadataCreds creds plugin + C.grpcChannelCreate e creds' chanargs + where + (Endpoint e) = clientServerEndpoint createClient :: GRPC -> ClientConfig -> IO Client createClient grpc clientConfig = @@ -115,63 +121,71 @@ destroyClient Client{..} = do grpcDebug "destroyClient: shutting down CQ." shutdownResult <- shutdownCompletionQueue clientCQ case shutdownResult of - Left x -> do putStrLn $ "Failed to stop client CQ: " ++ show x - putStrLn $ "Trying to shut down anyway." + Left x -> do + putStrLn $ "Failed to stop client CQ: " ++ show x + putStrLn $ "Trying to shut down anyway." Right _ -> return () withClient :: GRPC -> ClientConfig -> (Client -> IO a) -> IO a -withClient grpc config = bracket (createClient grpc config) - (\c -> grpcDebug "withClient: destroying." - >> destroyClient c) +withClient grpc config = + bracket + (createClient grpc config) + ( \c -> + grpcDebug "withClient: destroying." + >> destroyClient c + ) clientConnectivity :: Client -> IO C.ConnectivityState clientConnectivity Client{..} = C.grpcChannelCheckConnectivityState clientChannel False ---TODO: We should probably also register client methods on startup. +-- TODO: We should probably also register client methods on startup. -- | Register a method on the client so that we can call it with -- 'clientRequest'. -clientRegisterMethod :: Client - -> MethodName - -> IO (C.CallHandle) +clientRegisterMethod :: + Client -> + MethodName -> + IO (C.CallHandle) clientRegisterMethod Client{..} meth = do let host = fromMaybe (unEndpoint (clientServerEndpoint clientConfig)) (clientAuthority clientConfig) - C.grpcChannelRegisterCall clientChannel - (unMethodName meth) - host - C.reserved - - -clientRegisterMethodNormal :: Client - -> MethodName - -> IO (RegisteredMethod 'Normal) + C.grpcChannelRegisterCall + clientChannel + (unMethodName meth) + host + C.reserved + +clientRegisterMethodNormal :: + Client -> + MethodName -> + IO (RegisteredMethod 'Normal) clientRegisterMethodNormal c meth = do let e = clientServerEndpoint (clientConfig c) h <- clientRegisterMethod c meth return $ RegisteredMethodNormal meth e h - -clientRegisterMethodClientStreaming :: Client - -> MethodName - -> IO (RegisteredMethod 'ClientStreaming) +clientRegisterMethodClientStreaming :: + Client -> + MethodName -> + IO (RegisteredMethod 'ClientStreaming) clientRegisterMethodClientStreaming c meth = do let e = clientServerEndpoint (clientConfig c) h <- clientRegisterMethod c meth - return $ RegisteredMethodClientStreaming meth e h + return $ RegisteredMethodClientStreaming meth e h -clientRegisterMethodServerStreaming :: Client - -> MethodName - -> IO (RegisteredMethod 'ServerStreaming) +clientRegisterMethodServerStreaming :: + Client -> + MethodName -> + IO (RegisteredMethod 'ServerStreaming) clientRegisterMethodServerStreaming c meth = do let e = clientServerEndpoint (clientConfig c) h <- clientRegisterMethod c meth return $ RegisteredMethodServerStreaming meth e h - -clientRegisterMethodBiDiStreaming :: Client - -> MethodName - -> IO (RegisteredMethod 'BiDiStreaming) +clientRegisterMethodBiDiStreaming :: + Client -> + MethodName -> + IO (RegisteredMethod 'BiDiStreaming) clientRegisterMethodBiDiStreaming c meth = do let e = clientServerEndpoint (clientConfig c) h <- clientRegisterMethod c meth @@ -180,48 +194,57 @@ clientRegisterMethodBiDiStreaming c meth = do -- | Create a new call on the client for a registered method. -- Returns 'Left' if the CQ is shutting down or if the job to create a call -- timed out. -clientCreateCall :: Client - -> RegisteredMethod mt - -> TimeoutSeconds - -> IO (Either GRPCIOError ClientCall) +clientCreateCall :: + Client -> + RegisteredMethod mt -> + TimeoutSeconds -> + IO (Either GRPCIOError ClientCall) clientCreateCall c rm ts = clientCreateCallParent c rm ts Nothing -- | For servers that act as clients to other gRPC servers, this version creates -- a client call with an optional parent server call. This allows for cascading -- call cancellation from the `ServerCall` to the `ClientCall`. -clientCreateCallParent :: Client - -> RegisteredMethod mt - -> TimeoutSeconds - -> Maybe (ServerCall a) - -- ^ Optional parent call for cascading cancellation. - -> IO (Either GRPCIOError ClientCall) +clientCreateCallParent :: + Client -> + RegisteredMethod mt -> + TimeoutSeconds -> + -- | Optional parent call for cascading cancellation. + Maybe (ServerCall a) -> + IO (Either GRPCIOError ClientCall) clientCreateCallParent Client{..} rm timeout parent = do C.withDeadlineSeconds timeout $ \deadline -> do - channelCreateCall clientChannel parent C.propagateDefaults - clientCQ (methodHandle rm) deadline + channelCreateCall + clientChannel + parent + C.propagateDefaults + clientCQ + (methodHandle rm) + deadline -- | Handles safe creation and cleanup of a client call -withClientCall :: Client - -> RegisteredMethod mt - -> TimeoutSeconds - -> (ClientCall -> IO (Either GRPCIOError a)) - -> IO (Either GRPCIOError a) +withClientCall :: + Client -> + RegisteredMethod mt -> + TimeoutSeconds -> + (ClientCall -> IO (Either GRPCIOError a)) -> + IO (Either GRPCIOError a) withClientCall cl rm tm = withClientCallParent cl rm tm Nothing -- | Handles safe creation and cleanup of a client call, with an optional parent -- call parameter. This allows for cancellation to cascade from the parent -- `ServerCall` to the created `ClientCall`. Obviously, this is only useful if -- the given gRPC client is also a server. -withClientCallParent :: Client - -> RegisteredMethod mt - -> TimeoutSeconds - -> Maybe (ServerCall b) - -- ^ Optional parent call for cascading cancellation - -> (ClientCall -> IO (Either GRPCIOError a)) - -> IO (Either GRPCIOError a) +withClientCallParent :: + Client -> + RegisteredMethod mt -> + TimeoutSeconds -> + -- | Optional parent call for cascading cancellation + Maybe (ServerCall b) -> + (ClientCall -> IO (Either GRPCIOError a)) -> + IO (Either GRPCIOError a) withClientCallParent cl rm tm parent f = bracket (clientCreateCallParent cl rm tm parent) cleanup $ \case - Left e -> return (Left e) + Left e -> return (Left e) Right c -> f c where cleanup (Left _) = pure () @@ -232,21 +255,25 @@ withClientCallParent cl rm tm parent f = data NormalRequestResult = NormalRequestResult { rspBody :: ByteString - , initMD :: MetadataMap -- ^ initial metadata - , trailMD :: MetadataMap -- ^ trailing metadata + , initMD :: MetadataMap + -- ^ initial metadata + , trailMD :: MetadataMap + -- ^ trailing metadata , rspCode :: C.StatusCode , details :: StatusDetails } deriving (Show, Eq) -- | Function for assembling call result when the 'MethodType' is 'Normal'. -compileNormalRequestResults :: [OpRecvResult] - -> Either GRPCIOError NormalRequestResult +compileNormalRequestResults :: + [OpRecvResult] -> + Either GRPCIOError NormalRequestResult compileNormalRequestResults - [OpRecvInitialMetadataResult m, - OpRecvMessageResult (Just body), - OpRecvStatusOnClientResult m2 status details] - = Right $ NormalRequestResult body m m2 status (StatusDetails details) + [ OpRecvInitialMetadataResult m + , OpRecvMessageResult (Just body) + , OpRecvStatusOnClientResult m2 status details + ] = + Right $ NormalRequestResult body m m2 status (StatusDetails details) compileNormalRequestResults x = case extractStatusInfo x of Nothing -> Left GRPCIOUnknownError @@ -258,23 +285,31 @@ compileNormalRequestResults x = -- | First parameter is initial server metadata. type ClientReaderHandler = ClientCall -> MetadataMap -> StreamRecv ByteString -> IO () -type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails) - -clientReader :: Client - -> RegisteredMethod 'ServerStreaming - -> TimeoutSeconds - -> ByteString -- ^ The body of the request - -> MetadataMap -- ^ Metadata to send with the request - -> ClientReaderHandler - -> IO (Either GRPCIOError ClientReaderResult) -clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = + +type ClientReaderResult = (MetadataMap, C.StatusCode, StatusDetails) + +clientReader :: + Client -> + RegisteredMethod 'ServerStreaming -> + TimeoutSeconds -> + -- | The body of the request + ByteString -> + -- | Metadata to send with the request + MetadataMap -> + ClientReaderHandler -> + IO (Either GRPCIOError ClientReaderResult) +clientReader cl@Client{clientCQ = cq} rm tm body initMeta f = withClientCall cl rm tm go where go cc@(unsafeCC -> c) = runExceptT $ do - void $ runOps' c cq [ OpSendInitialMetadata initMeta - , OpSendMessage body - , OpSendCloseFromClient - ] + void $ + runOps' + c + cq + [ OpSendInitialMetadata initMeta + , OpSendMessage body + , OpSendCloseFromClient + ] srvMD <- recvInitialMetadata c cq liftIO $ f cc srvMD (streamRecvPrim c cq) recvStatusOnClient c cq @@ -283,23 +318,34 @@ clientReader cl@Client{ clientCQ = cq } rm tm body initMeta f = -- clientWriter (client side of client streaming mode) type ClientWriterHandler = StreamSend ByteString -> IO () -type ClientWriterResult = (Maybe ByteString, MetadataMap, MetadataMap, - C.StatusCode, StatusDetails) - -clientWriter :: Client - -> RegisteredMethod 'ClientStreaming - -> TimeoutSeconds - -> MetadataMap -- ^ Initial client metadata - -> ClientWriterHandler - -> IO (Either GRPCIOError ClientWriterResult) +type ClientWriterResult = + ( Maybe ByteString + , MetadataMap + , MetadataMap + , C.StatusCode + , StatusDetails + ) + +clientWriter :: + Client -> + RegisteredMethod 'ClientStreaming -> + TimeoutSeconds -> + -- | Initial client metadata + MetadataMap -> + ClientWriterHandler -> + IO (Either GRPCIOError ClientWriterResult) clientWriter cl rm tm initMeta = withClientCall cl rm tm . clientWriterCmn cl initMeta -clientWriterCmn :: Client -- ^ The active client - -> MetadataMap -- ^ Initial client metadata - -> ClientWriterHandler - -> ClientCall -- ^ The active client call - -> IO (Either GRPCIOError ClientWriterResult) +clientWriterCmn :: + -- | The active client + Client -> + -- | Initial client metadata + MetadataMap -> + ClientWriterHandler -> + -- | The active client call + ClientCall -> + IO (Either GRPCIOError ClientWriterResult) clientWriterCmn (clientCQ -> cq) initMeta f (unsafeCC -> c) = runExceptT $ do sendInitialMetadata c cq initMeta @@ -307,40 +353,42 @@ clientWriterCmn (clientCQ -> cq) initMeta f (unsafeCC -> c) = sendSingle c cq OpSendCloseFromClient let ops = [OpRecvInitialMetadata, OpRecvMessage, OpRecvStatusOnClient] runOps' c cq ops >>= \case - CWRFinal mmsg initMD trailMD st ds - -> return (mmsg, initMD, trailMD, st, ds) + CWRFinal mmsg initMD trailMD st ds -> + return (mmsg, initMD, trailMD, st, ds) _ -> throwE (GRPCIOInternalUnexpectedRecv "clientWriter") -pattern CWRFinal :: Maybe ByteString - -> MetadataMap - -> MetadataMap - -> C.StatusCode - -> StatusDetails - -> [OpRecvResult] -pattern CWRFinal mmsg initMD trailMD st ds - <- [ OpRecvInitialMetadataResult initMD - , OpRecvMessageResult mmsg - , OpRecvStatusOnClientResult trailMD st (StatusDetails -> ds) - ] +pattern CWRFinal :: + Maybe ByteString -> + MetadataMap -> + MetadataMap -> + C.StatusCode -> + StatusDetails -> + [OpRecvResult] +pattern CWRFinal mmsg initMD trailMD st ds <- + [ OpRecvInitialMetadataResult initMD + , OpRecvMessageResult mmsg + , OpRecvStatusOnClientResult trailMD st (StatusDetails -> ds) + ] -------------------------------------------------------------------------------- -- clientRW (client side of bidirectional streaming mode) -type ClientRWHandler - = ClientCall - -> IO (Either GRPCIOError MetadataMap) - -> StreamRecv ByteString - -> StreamSend ByteString - -> WritesDone - -> IO () +type ClientRWHandler = + ClientCall -> + IO (Either GRPCIOError MetadataMap) -> + StreamRecv ByteString -> + StreamSend ByteString -> + WritesDone -> + IO () type ClientRWResult = (MetadataMap, C.StatusCode, StatusDetails) -clientRW :: Client - -> RegisteredMethod 'BiDiStreaming - -> TimeoutSeconds - -> MetadataMap - -> ClientRWHandler - -> IO (Either GRPCIOError ClientRWResult) +clientRW :: + Client -> + RegisteredMethod 'BiDiStreaming -> + TimeoutSeconds -> + MetadataMap -> + ClientRWHandler -> + IO (Either GRPCIOError ClientRWResult) clientRW cl rm tm initMeta f = withClientCall cl rm tm (\cc -> clientRW' cl cc initMeta f) @@ -349,11 +397,12 @@ clientRW cl rm tm initMeta f = -- for the half-close, after all threads have completed writing. TODO: It'd be -- nice to find a way to type-enforce this usage pattern rather than accomplish -- it via usage convention and documentation. -clientRW' :: Client - -> ClientCall - -> MetadataMap - -> ClientRWHandler - -> IO (Either GRPCIOError ClientRWResult) +clientRW' :: + Client -> + ClientCall -> + MetadataMap -> + ClientRWHandler -> + IO (Either GRPCIOError ClientRWResult) clientRW' (clientCQ -> cq) cc@(unsafeCC -> c) initMeta f = runExceptT $ do sendInitialMetadata c cq initMeta @@ -390,15 +439,17 @@ clientRW' (clientCQ -> cq) cc@(unsafeCC -> c) initMeta f = runExceptT $ do let getMD = modifyMVar mdmv $ \case Just emd -> return (Just emd, emd) - Nothing -> do -- getMD invoked before recv + Nothing -> do + -- getMD invoked before recv emd <- runExceptT (recvInitialMetadata c cq) return (Just emd, emd) recv = modifyMVar mdmv $ \case Just emd -> (Just emd,) <$> streamRecvPrim c cq - Nothing -> -- recv invoked before getMD + Nothing -> + -- recv invoked before getMD runExceptT (recvInitialMsgMD c cq) >>= \case - Left e -> return (Just (Left e), Left e) + Left e -> return (Just (Left e), Left e) Right (mbs, md) -> return (Just (Right md), Right mbs) send = streamSendPrim c cq @@ -422,38 +473,40 @@ clientRW' (clientCQ -> cq) cc@(unsafeCC -> c) initMeta f = runExceptT $ do -- | Make a request of the given method with the given body. Returns the -- server's response. -clientRequest - :: Client - -> RegisteredMethod 'Normal - -> TimeoutSeconds - -> ByteString - -- ^ The body of the request - -> MetadataMap - -- ^ Metadata to send with the request - -> IO (Either GRPCIOError NormalRequestResult) +clientRequest :: + Client -> + RegisteredMethod 'Normal -> + TimeoutSeconds -> + -- | The body of the request + ByteString -> + -- | Metadata to send with the request + MetadataMap -> + IO (Either GRPCIOError NormalRequestResult) clientRequest c = clientRequestParent c Nothing -- | Like 'clientRequest', but allows the user to supply an optional parent -- call, so that call cancellation can be propagated from the parent to the -- child. This is intended for servers that call other servers. -clientRequestParent - :: Client - -> Maybe (ServerCall a) - -- ^ optional parent call - -> RegisteredMethod 'Normal - -> TimeoutSeconds - -> ByteString - -- ^ The body of the request - -> MetadataMap - -- ^ Metadata to send with the request - -> IO (Either GRPCIOError NormalRequestResult) +clientRequestParent :: + Client -> + -- | optional parent call + Maybe (ServerCall a) -> + RegisteredMethod 'Normal -> + TimeoutSeconds -> + -- | The body of the request + ByteString -> + -- | Metadata to send with the request + MetadataMap -> + IO (Either GRPCIOError NormalRequestResult) clientRequestParent cl@(clientCQ -> cq) p rm tm body initMeta = withClientCallParent cl rm tm p (fmap join . go) where go (unsafeCC -> c) = -- NB: the send and receive operations below *must* be in separate -- batches, or the client hangs when the server can't be reached. - runOps c cq + runOps + c + cq [ OpSendInitialMetadata initMeta , OpSendMessage body , OpSendCloseFromClient @@ -463,7 +516,9 @@ clientRequestParent cl@(clientCQ -> cq) p rm tm body initMeta = grpcDebug "clientRequest(R) : batch error sending." return $ Left x Right rs -> - runOps c cq + runOps + c + cq [ OpRecvInitialMetadata , OpRecvMessage , OpRecvStatusOnClient diff --git a/core/src/Network/GRPC/LowLevel/Client/Unregistered.hs b/core/src/Network/GRPC/LowLevel/Client/Unregistered.hs index 7fa5cbe3..4c3cc0e7 100644 --- a/core/src/Network/GRPC/LowLevel/Client/Unregistered.hs +++ b/core/src/Network/GRPC/LowLevel/Client/Unregistered.hs @@ -1,46 +1,56 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE ViewPatterns #-} module Network.GRPC.LowLevel.Client.Unregistered where -import Control.Arrow -import Control.Exception (bracket) -import Control.Monad (join) -import Data.ByteString (ByteString) -import Foreign.Ptr (nullPtr) -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Constants as C -import qualified Network.GRPC.Unsafe.Time as C +import Control.Arrow +import Control.Exception (bracket) +import Control.Monad (join) +import Data.ByteString (ByteString) +import Foreign.Ptr (nullPtr) +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Constants as C +import qualified Network.GRPC.Unsafe.Time as C -import Network.GRPC.LowLevel.Call -import Network.GRPC.LowLevel.Client (Client (..), - NormalRequestResult (..), - clientServerEndpoint, - compileNormalRequestResults) -import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds) +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.Client ( + Client (..), + NormalRequestResult (..), + clientServerEndpoint, + compileNormalRequestResults, + ) +import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds) import qualified Network.GRPC.LowLevel.CompletionQueue.Unregistered as U -import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Op -- | Create a call on the client for an endpoint without using the -- method registration machinery. In practice, we'll probably only use the -- registered method version, but we include this for completeness and testing. -clientCreateCall :: Client - -> MethodName - -> TimeoutSeconds - -> IO (Either GRPCIOError ClientCall) +clientCreateCall :: + Client -> + MethodName -> + TimeoutSeconds -> + IO (Either GRPCIOError ClientCall) clientCreateCall Client{..} meth timeout = do let parentCall = C.Call nullPtr C.withDeadlineSeconds timeout $ \deadline -> do - U.channelCreateCall clientChannel parentCall C.propagateDefaults - clientCQ meth (clientServerEndpoint clientConfig) deadline + U.channelCreateCall + clientChannel + parentCall + C.propagateDefaults + clientCQ + meth + (clientServerEndpoint clientConfig) + deadline -withClientCall :: Client - -> MethodName - -> TimeoutSeconds - -> (ClientCall -> IO (Either GRPCIOError a)) - -> IO (Either GRPCIOError a) +withClientCall :: + Client -> + MethodName -> + TimeoutSeconds -> + (ClientCall -> IO (Either GRPCIOError a)) -> + IO (Either GRPCIOError a) withClientCall client method timeout f = bracket (clientCreateCall client method timeout) cleanup $ \case Left x -> return $ Left x @@ -53,27 +63,31 @@ withClientCall client method timeout f = -- | Makes a normal (non-streaming) request without needing to register a method -- first. Probably only useful for testing. -clientRequest :: Client - -> MethodName - -- ^ Method name, e.g. "/foo" - -> TimeoutSeconds - -- ^ "Number of seconds until request times out" - -> ByteString - -- ^ Request body. - -> MetadataMap - -- ^ Request metadata. - -> IO (Either GRPCIOError NormalRequestResult) +clientRequest :: + Client -> + -- | Method name, e.g. "/foo" + MethodName -> + -- | "Number of seconds until request times out" + TimeoutSeconds -> + -- | Request body. + ByteString -> + -- | Request metadata. + MetadataMap -> + IO (Either GRPCIOError NormalRequestResult) clientRequest cl@(clientCQ -> cq) meth tm body initMeta = join <$> withClientCall cl meth tm go where go (unsafeCC -> c) = do - results <- runOps c cq - [ OpSendInitialMetadata initMeta - , OpSendMessage body - , OpSendCloseFromClient - , OpRecvInitialMetadata - , OpRecvMessage - , OpRecvStatusOnClient - ] + results <- + runOps + c + cq + [ OpSendInitialMetadata initMeta + , OpSendMessage body + , OpSendCloseFromClient + , OpRecvInitialMetadata + , OpRecvMessage + , OpRecvStatusOnClient + ] grpcDebug "clientRequest(U): ops ran." return $ right compileNormalRequestResults results diff --git a/core/src/Network/GRPC/LowLevel/CompletionQueue.hs b/core/src/Network/GRPC/LowLevel/CompletionQueue.hs index a1edbc8d..d39d47c8 100644 --- a/core/src/Network/GRPC/LowLevel/CompletionQueue.hs +++ b/core/src/Network/GRPC/LowLevel/CompletionQueue.hs @@ -1,3 +1,11 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} + -- | Unlike most of the other internal low-level modules, we don't export -- everything here. There are several things in here that, if accessed, could -- cause race conditions, so we only expose functions that are thread safe. @@ -9,53 +17,46 @@ -- `Network.GRPC.LowLevel.CompletionQueue.Unregistered`. Type definitions and -- implementation details to both are kept in -- `Network.GRPC.LowLevel.CompletionQueue.Internal`. - -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE ViewPatterns #-} - -module Network.GRPC.LowLevel.CompletionQueue - ( CompletionQueue - , withCompletionQueue - , createCompletionQueue - , shutdownCompletionQueue - , pluck - , startBatch - , channelCreateCall - , TimeoutSeconds - , isEventSuccessful - , serverRegisterCompletionQueue - , serverShutdownAndNotify - , serverRequestCall - , newTag - ) +module Network.GRPC.LowLevel.CompletionQueue ( + CompletionQueue, + withCompletionQueue, + createCompletionQueue, + shutdownCompletionQueue, + pluck, + startBatch, + channelCreateCall, + TimeoutSeconds, + isEventSuccessful, + serverRegisterCompletionQueue, + serverShutdownAndNotify, + serverRequestCall, + newTag, +) where -import Control.Concurrent.STM.TVar (newTVarIO) -import Control.Exception (bracket) -import Control.Monad.Managed -import Control.Monad.Trans.Class (MonadTrans (lift)) -import Control.Monad.Trans.Except -import Data.IORef (newIORef) -import Data.List (intersperse) -import Foreign.Ptr (nullPtr) -import Foreign.Storable (peek) -import Network.GRPC.LowLevel.Call -import Network.GRPC.LowLevel.CompletionQueue.Internal -import Network.GRPC.LowLevel.GRPC -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Constants as C -import qualified Network.GRPC.Unsafe.Metadata as C -import qualified Network.GRPC.Unsafe.Op as C -import qualified Network.GRPC.Unsafe.Time as C +import Control.Concurrent.STM.TVar (newTVarIO) +import Control.Exception (bracket) +import Control.Monad.Managed +import Control.Monad.Trans.Class (MonadTrans (lift)) +import Control.Monad.Trans.Except +import Data.IORef (newIORef) +import Data.List (intersperse) +import Foreign.Ptr (nullPtr) +import Foreign.Storable (peek) +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.CompletionQueue.Internal +import Network.GRPC.LowLevel.GRPC +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Constants as C +import qualified Network.GRPC.Unsafe.Metadata as C +import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Time as C withCompletionQueue :: GRPC -> (CompletionQueue -> IO a) -> IO a -withCompletionQueue grpc = bracket (createCompletionQueue grpc) - shutdownCompletionQueue +withCompletionQueue grpc = + bracket + (createCompletionQueue grpc) + shutdownCompletionQueue createCompletionQueue :: GRPC -> IO CompletionQueue createCompletionQueue _ = do @@ -69,53 +70,94 @@ createCompletionQueue _ = do -- | Very simple wrapper around 'grpcCallStartBatch'. Throws 'GRPCIOShutdown' -- without calling 'grpcCallStartBatch' if the queue is shutting down. -- Throws 'CallError' if 'grpcCallStartBatch' returns a non-OK code. -startBatch :: CompletionQueue -> C.Call -> C.OpArray -> Int -> C.Tag - -> IO (Either GRPCIOError ()) +startBatch :: + CompletionQueue -> + C.Call -> + C.OpArray -> + Int -> + C.Tag -> + IO (Either GRPCIOError ()) startBatch cq call opArray opArraySize tag = - withPermission Push cq $ fmap throwIfCallError $ do - grpcDebug $ "startBatch: calling grpc_call_start_batch with pointers: " - ++ show call ++ " " ++ show opArray - res <- C.grpcCallStartBatch call opArray opArraySize tag C.reserved - grpcDebug "startBatch: grpc_call_start_batch call returned." - return res + withPermission Push cq $ fmap throwIfCallError $ do + grpcDebug $ + "startBatch: calling grpc_call_start_batch with pointers: " + ++ show call + ++ " " + ++ show opArray + res <- C.grpcCallStartBatch call opArray opArraySize tag C.reserved + grpcDebug "startBatch: grpc_call_start_batch call returned." + return res -channelCreateCall :: C.Channel - -> Maybe (ServerCall a) - -> C.PropagationMask - -> CompletionQueue - -> C.CallHandle - -> C.CTimeSpecPtr - -> IO (Either GRPCIOError ClientCall) +channelCreateCall :: + C.Channel -> + Maybe (ServerCall a) -> + C.PropagationMask -> + CompletionQueue -> + C.CallHandle -> + C.CTimeSpecPtr -> + IO (Either GRPCIOError ClientCall) channelCreateCall - chan parent mask cq@CompletionQueue{..} handle deadline = - withPermission Push cq $ do - let parentPtr = maybe (C.Call nullPtr) unsafeSC parent - grpcDebug $ "channelCreateCall: call with " - ++ concat (intersperse " " [show chan, show parentPtr, - show mask, - show unsafeCQ, show handle, - show deadline]) - call <- C.grpcChannelCreateRegisteredCall chan parentPtr mask unsafeCQ - handle deadline C.reserved - return $ Right $ ClientCall call + chan + parent + mask + cq@CompletionQueue{..} + handle + deadline = + withPermission Push cq $ do + let parentPtr = maybe (C.Call nullPtr) unsafeSC parent + grpcDebug $ + "channelCreateCall: call with " + ++ concat + ( intersperse + " " + [ show chan + , show parentPtr + , show mask + , show unsafeCQ + , show handle + , show deadline + ] + ) + call <- + C.grpcChannelCreateRegisteredCall + chan + parentPtr + mask + unsafeCQ + handle + deadline + C.reserved + return $ Right $ ClientCall call -- | Create the call object to handle a registered call. -serverRequestCall :: RegisteredMethod mt - -> C.Server - -> CompletionQueue -- ^ server CQ - -> CompletionQueue -- ^ call CQ - -> IO (Either GRPCIOError (ServerCall (MethodPayload mt))) +serverRequestCall :: + RegisteredMethod mt -> + C.Server -> + -- | server CQ + CompletionQueue -> + -- | call CQ + CompletionQueue -> + IO (Either GRPCIOError (ServerCall (MethodPayload mt))) serverRequestCall rm s scq ccq = -- NB: The method type dictates whether or not a payload is present, according -- to the payloadHandling function. We do not allocate a buffer for the -- payload when it is not present. withPermission Push scq . with allocs $ \(dead, call, pay, meta) -> withPermission Pluck scq $ do - md <- peek meta + md <- peek meta tag <- newTag scq dbug $ "got pluck permission, registering call for tag=" ++ show tag - ce <- C.grpcServerRequestRegisteredCall s (methodHandle rm) call dead md - pay (unsafeCQ ccq) (unsafeCQ scq) tag + ce <- + C.grpcServerRequestRegisteredCall + s + (methodHandle rm) + call + dead + md + pay + (unsafeCQ ccq) + (unsafeCQ scq) + tag runExceptT $ case ce of C.CallOk -> do ExceptT $ do @@ -124,20 +166,21 @@ serverRequestCall rm s scq ccq = return r lift $ ServerCall - <$> peek call - <*> return ccq - <*> C.getAllMetadataArray md - <*> extractPayload rm pay - <*> convertDeadline dead + <$> peek call + <*> return ccq + <*> C.getAllMetadataArray md + <*> extractPayload rm pay + <*> convertDeadline dead _ -> do lift $ dbug $ "Throwing callError: " ++ show ce throwE (GRPCIOCallError ce) where - allocs = (,,,) - <$> mgdPtr - <*> mgdPtr - <*> mgdPayload (methodType rm) - <*> managed C.withMetadataArrayPtr + allocs = + (,,,) + <$> mgdPtr + <*> mgdPtr + <*> mgdPayload (methodType rm) + <*> managed C.withMetadataArrayPtr dbug = grpcDebug . ("serverRequestCall(R): " ++) convertDeadline timeSpecPtr = C.timeSpec <$> peek timeSpecPtr diff --git a/core/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs b/core/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs index 8fd2d54a..4e2061c0 100644 --- a/core/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs +++ b/core/src/Network/GRPC/LowLevel/CompletionQueue/Internal.hs @@ -2,18 +2,22 @@ module Network.GRPC.LowLevel.CompletionQueue.Internal where -import Control.Concurrent.STM (atomically, retry, check) -import Control.Concurrent.STM.TVar (TVar, modifyTVar', readTVar, - writeTVar) -import Control.Exception (bracket) -import Control.Monad -import Data.IORef (IORef, atomicModifyIORef') -import Foreign.Ptr (nullPtr, plusPtr) -import Network.GRPC.LowLevel.GRPC -import qualified Network.GRPC.Unsafe as C +import Control.Concurrent.STM (atomically, check, retry) +import Control.Concurrent.STM.TVar ( + TVar, + modifyTVar', + readTVar, + writeTVar, + ) +import Control.Exception (bracket) +import Control.Monad +import Data.IORef (IORef, atomicModifyIORef') +import Foreign.Ptr (nullPtr, plusPtr) +import Network.GRPC.LowLevel.GRPC +import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Constants as C -import qualified Network.GRPC.Unsafe.Time as C -import System.Timeout (timeout) +import qualified Network.GRPC.Unsafe.Time as C +import System.Timeout (timeout) -- NOTE: the concurrency requirements for a CompletionQueue are a little -- complicated. There are two read operations: next and pluck. We can either @@ -40,27 +44,28 @@ import System.Timeout (timeout) -- are used to wait for batches gRPC operations ('Op's) to finish running, as -- well as wait for various other operations, such as server shutdown, pinging, -- checking to see if we've been disconnected, and so forth. -data CompletionQueue = CompletionQueue {unsafeCQ :: C.CompletionQueue, - -- ^ All access to this field must be - -- guarded by a check of 'shuttingDown'. - currentPluckers :: TVar Int, - -- ^ Used to limit the number of - -- concurrent calls to pluck on this - -- queue. - -- The max value is set by gRPC in - -- 'C.maxCompletionQueuePluckers' - currentPushers :: TVar Int, - -- ^ Used to prevent new work from - -- being pushed onto the queue when - -- the queue begins to shut down. - shuttingDown :: TVar Bool, - -- ^ Used to prevent new pluck calls on - -- the queue when the queue begins to - -- shut down. - nextTag :: IORef Int - -- ^ Used to supply unique tags for work - -- items pushed onto the queue. - } +data CompletionQueue = CompletionQueue + { unsafeCQ :: C.CompletionQueue + -- ^ All access to this field must be + -- guarded by a check of 'shuttingDown'. + , currentPluckers :: TVar Int + -- ^ Used to limit the number of + -- concurrent calls to pluck on this + -- queue. + -- The max value is set by gRPC in + -- 'C.maxCompletionQueuePluckers' + , currentPushers :: TVar Int + -- ^ Used to prevent new work from + -- being pushed onto the queue when + -- the queue begins to shut down. + , shuttingDown :: TVar Bool + -- ^ Used to prevent new pluck calls on + -- the queue when the queue begins to + -- shut down. + , nextTag :: IORef Int + -- ^ Used to supply unique tags for work + -- items pushed onto the queue. + } instance Show CompletionQueue where show = show . unsafeCQ @@ -73,15 +78,16 @@ data CQOpType = Push | Pluck deriving (Show, Eq, Enum) -- practical perspective, that should be safe. newTag :: CompletionQueue -> IO C.Tag newTag CompletionQueue{..} = do - i <- atomicModifyIORef' nextTag (\i -> (i+1,i)) + i <- atomicModifyIORef' nextTag (\i -> (i + 1, i)) return $ C.Tag $ plusPtr nullPtr i -- | Safely brackets an operation that pushes work onto or plucks results from -- the given 'CompletionQueue'. -withPermission :: CQOpType - -> CompletionQueue - -> IO (Either GRPCIOError a) - -> IO (Either GRPCIOError a) +withPermission :: + CQOpType -> + CompletionQueue -> + IO (Either GRPCIOError a) -> + IO (Either GRPCIOError a) withPermission op cq act = bracket acquire release $ \gotResource -> if gotResource then act else return (Left GRPCIOShutdown) where @@ -93,8 +99,10 @@ withPermission op cq act = bracket acquire release $ \gotResource -> then writeTVar (getCount op cq) (currCount + 1) else retry return (not isShuttingDown) - release gotResource = when gotResource $ - atomically $ modifyTVar' (getCount op cq) (subtract 1) + release gotResource = + when gotResource $ + atomically $ + modifyTVar' (getCount op cq) (subtract 1) -- | Waits for the given number of seconds for the given tag to appear on the -- completion queue. Throws 'GRPCIOShutdown' if the completion queue is shutting @@ -102,17 +110,21 @@ withPermission op cq act = bracket acquire release $ \gotResource -> -- doing client ops, provide @Nothing@ and the pluck will automatically fail if -- the deadline associated with the 'ClientCall' expires. If plucking -- 'serverRequestCall', this will block forever unless a timeout is given. -pluck :: CompletionQueue -> C.Tag -> Maybe TimeoutSeconds - -> IO (Either GRPCIOError ()) +pluck :: + CompletionQueue -> + C.Tag -> + Maybe TimeoutSeconds -> + IO (Either GRPCIOError ()) pluck cq tag mwait = do grpcDebug $ "pluck: called with tag=" ++ show tag ++ ",mwait=" ++ show mwait withPermission Pluck cq $ pluck' cq tag mwait -- Variant of pluck' which assumes pluck permission has been granted. -pluck' :: CompletionQueue - -> C.Tag - -> Maybe TimeoutSeconds - -> IO (Either GRPCIOError ()) +pluck' :: + CompletionQueue -> + C.Tag -> + Maybe TimeoutSeconds -> + IO (Either GRPCIOError ()) pluck' CompletionQueue{..} tag mwait = maybe C.withInfiniteDeadline C.withDeadlineSeconds mwait $ \dead -> do grpcDebug $ "pluck: blocking on grpc_completion_queue_pluck for tag=" ++ show tag @@ -135,14 +147,14 @@ isEventSuccessful (C.Event C.OpComplete True _) = True isEventSuccessful _ = False maxWorkPushers :: Int -maxWorkPushers = 100 --TODO: figure out what this should be. +maxWorkPushers = 100 -- TODO: figure out what this should be. getCount :: CQOpType -> CompletionQueue -> TVar Int -getCount Push = currentPushers +getCount Push = currentPushers getCount Pluck = currentPluckers getLimit :: CQOpType -> Int -getLimit Push = maxWorkPushers +getLimit Push = maxWorkPushers getLimit Pluck = C.maxCompletionQueuePluckers -- | Shuts down the completion queue. See the comment above 'CompletionQueue' @@ -153,23 +165,23 @@ shutdownCompletionQueue :: CompletionQueue -> IO (Either GRPCIOError ()) shutdownCompletionQueue CompletionQueue{..} = do atomically $ writeTVar shuttingDown True atomically $ do - readTVar currentPushers >>= check . (==0) - readTVar currentPluckers >>= check . (==0) - --drain the queue + readTVar currentPushers >>= check . (== 0) + readTVar currentPluckers >>= check . (== 0) + -- drain the queue C.grpcCompletionQueueShutdown unsafeCQ - loopRes <- timeout (5*10^(6::Int)) drainLoop + loopRes <- timeout (5 * 10 ^ (6 :: Int)) drainLoop grpcDebug $ "Got CQ loop shutdown result of: " ++ show loopRes case loopRes of Nothing -> return $ Left GRPCIOShutdownFailure Just () -> C.grpcCompletionQueueDestroy unsafeCQ >> return (Right ()) - - where drainLoop :: IO () - drainLoop = do - grpcDebug "drainLoop: before next() call" - ev <- C.withDeadlineSeconds 1 $ \deadline -> - C.grpcCompletionQueuePluck unsafeCQ C.noTag deadline C.reserved - grpcDebug $ "drainLoop: next() call got " ++ show ev - case C.eventCompletionType ev of - C.QueueShutdown -> return () - C.QueueTimeout -> drainLoop - C.OpComplete -> drainLoop + where + drainLoop :: IO () + drainLoop = do + grpcDebug "drainLoop: before next() call" + ev <- C.withDeadlineSeconds 1 $ \deadline -> + C.grpcCompletionQueuePluck unsafeCQ C.noTag deadline C.reserved + grpcDebug $ "drainLoop: next() call got " ++ show ev + case C.eventCompletionType ev of + C.QueueShutdown -> return () + C.QueueTimeout -> drainLoop + C.OpComplete -> drainLoop diff --git a/core/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs b/core/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs index f73c4c0d..bce6154c 100644 --- a/core/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs +++ b/core/src/Network/GRPC/LowLevel/CompletionQueue/Unregistered.hs @@ -1,52 +1,63 @@ -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE MultiWayIf #-} -{-# LANGUAGE PolyKinds #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiWayIf #-} +{-# LANGUAGE PolyKinds #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} module Network.GRPC.LowLevel.CompletionQueue.Unregistered where -import Control.Monad.Managed -import Control.Monad.Trans.Class (MonadTrans (lift)) -import Control.Monad.Trans.Except -import Foreign.Storable (peek) -import Network.GRPC.LowLevel.Call -import qualified Network.GRPC.LowLevel.Call.Unregistered as U -import Network.GRPC.LowLevel.CompletionQueue.Internal -import Network.GRPC.LowLevel.GRPC -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.Constants as C -import qualified Network.GRPC.Unsafe.Metadata as C -import qualified Network.GRPC.Unsafe.Time as C +import Control.Monad.Managed +import Control.Monad.Trans.Class (MonadTrans (lift)) +import Control.Monad.Trans.Except +import Foreign.Storable (peek) +import Network.GRPC.LowLevel.Call +import qualified Network.GRPC.LowLevel.Call.Unregistered as U +import Network.GRPC.LowLevel.CompletionQueue.Internal +import Network.GRPC.LowLevel.GRPC +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.Constants as C +import qualified Network.GRPC.Unsafe.Metadata as C +import qualified Network.GRPC.Unsafe.Time as C -channelCreateCall :: C.Channel - -> C.Call - -> C.PropagationMask - -> CompletionQueue - -> MethodName - -> Endpoint - -> C.CTimeSpecPtr - -> IO (Either GRPCIOError ClientCall) +channelCreateCall :: + C.Channel -> + C.Call -> + C.PropagationMask -> + CompletionQueue -> + MethodName -> + Endpoint -> + C.CTimeSpecPtr -> + IO (Either GRPCIOError ClientCall) channelCreateCall chan parent mask cq@CompletionQueue{..} meth endpt deadline = withPermission Push cq $ do - call <- C.grpcChannelCreateCall chan parent mask unsafeCQ - (unMethodName meth) (unEndpoint endpt) deadline C.reserved + call <- + C.grpcChannelCreateCall + chan + parent + mask + unsafeCQ + (unMethodName meth) + (unEndpoint endpt) + deadline + C.reserved return $ Right $ ClientCall call - -serverRequestCall :: C.Server - -> CompletionQueue -- ^ server CQ / notification CQ - -> CompletionQueue -- ^ call CQ - -> IO (Either GRPCIOError U.ServerCall) +serverRequestCall :: + C.Server -> + -- | server CQ / notification CQ + CompletionQueue -> + -- | call CQ + CompletionQueue -> + IO (Either GRPCIOError U.ServerCall) serverRequestCall s scq ccq = withPermission Push scq . with allocs $ \(call, meta, cd) -> withPermission Pluck scq $ do - md <- peek meta + md <- peek meta tag <- newTag scq dbug $ "got pluck permission, registering call for tag=" ++ show tag - ce <- C.grpcServerRequestCall s call cd md (unsafeCQ ccq) (unsafeCQ scq) tag + ce <- C.grpcServerRequestCall s call cd md (unsafeCQ ccq) (unsafeCQ scq) tag runExceptT $ case ce of C.CallOk -> do ExceptT $ do @@ -66,14 +77,15 @@ serverRequestCall s scq ccq = <*> return ccq <*> C.getAllMetadataArray md <*> (C.timeSpec <$> C.callDetailsGetDeadline cd) - <*> (MethodName <$> C.callDetailsGetMethod cd) - <*> (Host <$> C.callDetailsGetHost cd) + <*> (MethodName <$> C.callDetailsGetMethod cd) + <*> (Host <$> C.callDetailsGetHost cd) _ -> do lift $ dbug $ "Throwing callError: " ++ show ce throwE $ GRPCIOCallError ce where - allocs = (,,) - <$> mgdPtr - <*> managed C.withMetadataArrayPtr - <*> managed C.withCallDetails + allocs = + (,,) + <$> mgdPtr + <*> managed C.withMetadataArrayPtr + <*> managed C.withCallDetails dbug = grpcDebug . ("serverRequestCall(U): " ++) diff --git a/core/src/Network/GRPC/LowLevel/GRPC.hs b/core/src/Network/GRPC/LowLevel/GRPC.hs index e0705441..1f1a5320 100644 --- a/core/src/Network/GRPC/LowLevel/GRPC.hs +++ b/core/src/Network/GRPC/LowLevel/GRPC.hs @@ -1,28 +1,27 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE StandaloneDeriving #-} -{-# LANGUAGE MultiParamTypeClasses #-} -{-# LANGUAGE StandaloneDeriving #-} - -module Network.GRPC.LowLevel.GRPC( - GRPC -, withGRPC -, startGRPC -, stopGRPC -, GRPCIOError(..) -, throwIfCallError -, grpcDebug -, grpcDebug' -, threadDelaySecs -, MetadataMap(..) -, C.StatusDetails(..) +module Network.GRPC.LowLevel.GRPC ( + GRPC, + withGRPC, + startGRPC, + stopGRPC, + GRPCIOError (..), + throwIfCallError, + grpcDebug, + grpcDebug', + threadDelaySecs, + MetadataMap (..), + C.StatusDetails (..), ) where -import Control.Concurrent (threadDelay, myThreadId) -import Control.Exception -import Data.Functor (($>)) -import Data.Typeable -import Network.GRPC.LowLevel.GRPC.MetadataMap (MetadataMap(..)) -import qualified Network.GRPC.Unsafe as C +import Control.Concurrent (myThreadId, threadDelay) +import Control.Exception +import Data.Functor (($>)) +import Data.Typeable +import Network.GRPC.LowLevel.GRPC.MetadataMap (MetadataMap (..)) +import qualified Network.GRPC.Unsafe as C import qualified Network.GRPC.Unsafe.Op as C -- | Functions as a proof that the gRPC core has been started. The gRPC core @@ -31,45 +30,46 @@ import qualified Network.GRPC.Unsafe.Op as C data GRPC = GRPC withGRPC :: (GRPC -> IO a) -> IO a -withGRPC = bracket startGRPC stopGRPC +withGRPC = bracket startGRPC stopGRPC -- | Start gRPC core and obtain a 'GRPC' witness. This function does not perform --- any cleanup once the gRPC server is no longer needed. +-- any cleanup once the gRPC server is no longer needed. -- --- Where possible, consider using 'withGRPC' which handles shutdown of gRPC --- automatically with 'bracket'. -startGRPC :: IO GRPC +-- Where possible, consider using 'withGRPC' which handles shutdown of gRPC +-- automatically with 'bracket'. +startGRPC :: IO GRPC startGRPC = C.grpcInit $> GRPC --- | Shutdown gRPC core given a 'GRPC' witnessing that gRPC core has been +-- | Shutdown gRPC core given a 'GRPC' witnessing that gRPC core has been -- initialized. stopGRPC :: GRPC -> IO () -stopGRPC GRPC = do +stopGRPC GRPC = do grpcDebug "withGRPC: shutting down" C.grpcShutdown -- | Describes all errors that can occur while running a GRPC-related IO -- action. -data GRPCIOError = GRPCIOCallError C.CallError - -- ^ Errors that can occur while the call is in flight. These - -- errors come from the core gRPC library directly. - | GRPCIOTimeout - -- ^ Indicates that we timed out while waiting for an - -- operation to complete on the 'CompletionQueue'. - | GRPCIOShutdown - -- ^ Indicates that the 'CompletionQueue' is shutting down - -- and no more work can be processed. This can happen if the - -- client or server is shutting down. - | GRPCIOShutdownFailure - -- ^ Thrown if a 'CompletionQueue' fails to shut down in a - -- reasonable amount of time. - | GRPCIOUnknownError - | GRPCIOBadStatusCode C.StatusCode C.StatusDetails - - | GRPCIODecodeError String - | GRPCIOInternalUnexpectedRecv String -- debugging description - | GRPCIOHandlerException String +data GRPCIOError + = -- | Errors that can occur while the call is in flight. These + -- errors come from the core gRPC library directly. + GRPCIOCallError C.CallError + | -- | Indicates that we timed out while waiting for an + -- operation to complete on the 'CompletionQueue'. + GRPCIOTimeout + | -- | Indicates that the 'CompletionQueue' is shutting down + -- and no more work can be processed. This can happen if the + -- client or server is shutting down. + GRPCIOShutdown + | -- | Thrown if a 'CompletionQueue' fails to shut down in a + -- reasonable amount of time. + GRPCIOShutdownFailure + | GRPCIOUnknownError + | GRPCIOBadStatusCode C.StatusCode C.StatusDetails + | GRPCIODecodeError String + | GRPCIOInternalUnexpectedRecv String -- debugging description + | GRPCIOHandlerException String deriving (Eq, Show, Typeable) + instance Exception GRPCIOError throwIfCallError :: C.CallError -> Either GRPCIOError () @@ -91,4 +91,4 @@ grpcDebug' str = do putStrLn $ "[" ++ show tid ++ "]: " ++ str threadDelaySecs :: Int -> IO () -threadDelaySecs = threadDelay . (* 10^(6::Int)) +threadDelaySecs = threadDelay . (* 10 ^ (6 :: Int)) diff --git a/core/src/Network/GRPC/LowLevel/GRPC/MetadataMap.hs b/core/src/Network/GRPC/LowLevel/GRPC/MetadataMap.hs index 698f26cf..ddf7083a 100644 --- a/core/src/Network/GRPC/LowLevel/GRPC/MetadataMap.hs +++ b/core/src/Network/GRPC/LowLevel/GRPC/MetadataMap.hs @@ -5,34 +5,32 @@ module Network.GRPC.LowLevel.GRPC.MetadataMap where import Data.ByteString (ByteString) -import Data.Function (on) import Data.Data (Data) -import Data.Typeable (Typeable) -import GHC.Exts (IsList(..)) -import Data.List (sortBy, groupBy) -import Data.Ord (comparing) +import Data.Function (on) +import Data.List (groupBy, sortBy) import qualified Data.List.NonEmpty as NE import qualified Data.Map.Strict as M +import Data.Ord (comparing) +import Data.Typeable (Typeable) +import GHC.Exts (IsList (..)) -{- | Represents metadata for a given RPC, consisting of key-value pairs (often - referred to as "GRPC custom metadata headers"). - - Keys are allowed to be repeated, with the 'last' value element (i.e., the - last-presented) usually taken as the value for that key (see 'lookupLast' and - 'lookupAll'). - - Since repeated keys are unlikely in practice, the 'IsList' instance for - 'MetadataMap' uses key-value pairs as items, and treats duplicates - appropriately. - - >>> lookupAll "k1" (fromList [("k1","x"), ("k2", "z"), ("k1", "y")]) - Just ("x" :| ["y"]) - - >>> lookupLast "k1" (fromList [("k1","x"), ("k2", "z"), ("k1", "y")]) - Just "y" --} - -newtype MetadataMap = MetadataMap +-- | Represents metadata for a given RPC, consisting of key-value pairs (often +-- referred to as "GRPC custom metadata headers"). +-- +-- Keys are allowed to be repeated, with the 'last' value element (i.e., the +-- last-presented) usually taken as the value for that key (see 'lookupLast' and +-- 'lookupAll'). +-- +-- Since repeated keys are unlikely in practice, the 'IsList' instance for +-- 'MetadataMap' uses key-value pairs as items, and treats duplicates +-- appropriately. +-- +-- >>> lookupAll "k1" (fromList [("k1","x"), ("k2", "z"), ("k1", "y")]) +-- Just ("x" :| ["y"]) +-- +-- >>> lookupLast "k1" (fromList [("k1","x"), ("k2", "z"), ("k1", "y")]) +-- Just "y" +newtype MetadataMap = MetadataMap {unMap :: M.Map ByteString [ByteString]} deriving (Data, Eq, Ord, Typeable) @@ -48,15 +46,17 @@ instance Monoid MetadataMap where instance IsList MetadataMap where type Item MetadataMap = (ByteString, ByteString) - fromList = MetadataMap - . M.fromList - . map (\xs -> ((fst . head) xs, map snd xs)) - . groupBy ((==) `on` fst) - . sortBy (comparing fst) - toList = concatMap (\(k,vs) -> map (k,) vs) - . map (fmap toList) - . M.toList - . unMap + fromList = + MetadataMap + . M.fromList + . map (\xs -> ((fst . head) xs, map snd xs)) + . groupBy ((==) `on` fst) + . sortBy (comparing fst) + toList = + concatMap (\(k, vs) -> map (k,) vs) + . map (fmap toList) + . M.toList + . unMap -- | Obtain all header values for a given header key, in presentation order. lookupAll :: ByteString -> MetadataMap -> Maybe (NE.NonEmpty ByteString) diff --git a/core/src/Network/GRPC/LowLevel/Op.hs b/core/src/Network/GRPC/LowLevel/Op.hs index 22fead45..e21a145e 100644 --- a/core/src/Network/GRPC/LowLevel/Op.hs +++ b/core/src/Network/GRPC/LowLevel/Op.hs @@ -1,48 +1,49 @@ {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ViewPatterns #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ViewPatterns #-} module Network.GRPC.LowLevel.Op where -import Control.Exception -import Control.Monad -import Control.Monad.Trans.Except -import Data.ByteString (ByteString) -import qualified Data.ByteString as B -import Data.Maybe (catMaybes) -import Foreign.C.Types (CInt) -import Foreign.Marshal.Alloc (free, malloc) -import Foreign.Ptr (Ptr, nullPtr) -import Foreign.Storable (peek) -import Network.GRPC.LowLevel.CompletionQueue -import Network.GRPC.LowLevel.GRPC -import qualified Network.GRPC.Unsafe as C (Call) -import qualified Network.GRPC.Unsafe.ByteBuffer as C -import qualified Network.GRPC.Unsafe.Metadata as C -import qualified Network.GRPC.Unsafe.Op as C -import qualified Network.GRPC.Unsafe.Slice as C -import Network.GRPC.Unsafe.Slice (Slice) +import Control.Exception +import Control.Monad +import Control.Monad.Trans.Except +import Data.ByteString (ByteString) +import qualified Data.ByteString as B +import Data.Maybe (catMaybes) +import Foreign.C.Types (CInt) +import Foreign.Marshal.Alloc (free, malloc) +import Foreign.Ptr (Ptr, nullPtr) +import Foreign.Storable (peek) +import Network.GRPC.LowLevel.CompletionQueue +import Network.GRPC.LowLevel.GRPC +import qualified Network.GRPC.Unsafe as C (Call) +import qualified Network.GRPC.Unsafe.ByteBuffer as C +import qualified Network.GRPC.Unsafe.Metadata as C +import qualified Network.GRPC.Unsafe.Op as C +import Network.GRPC.Unsafe.Slice (Slice) +import qualified Network.GRPC.Unsafe.Slice as C -- | Sum describing all possible send and receive operations that can be batched -- and executed by gRPC. Usually these are processed in a handful of -- combinations depending on the 'MethodType' of the call being run. -data Op = OpSendInitialMetadata MetadataMap - | OpSendMessage B.ByteString - | OpSendCloseFromClient - | OpSendStatusFromServer MetadataMap C.StatusCode StatusDetails - | OpRecvInitialMetadata - | OpRecvMessage - | OpRecvStatusOnClient - | OpRecvCloseOnServer - deriving (Show) +data Op + = OpSendInitialMetadata MetadataMap + | OpSendMessage B.ByteString + | OpSendCloseFromClient + | OpSendStatusFromServer MetadataMap C.StatusCode StatusDetails + | OpRecvInitialMetadata + | OpRecvMessage + | OpRecvStatusOnClient + | OpRecvCloseOnServer + deriving (Show) -- | Container holding the pointers to the C and gRPC data needed to execute the -- corresponding 'Op'. These are obviously unsafe, and should only be used with -- 'withOpContexts'. -data OpContext = - OpSendInitialMetadataContext C.MetadataKeyValPtr Int +data OpContext + = OpSendInitialMetadataContext C.MetadataKeyValPtr Int | OpSendMessageContext (C.ByteBuffer, C.Slice) | OpSendCloseFromClientContext | OpSendStatusFromServerContext C.MetadataKeyValPtr Int C.StatusCode Slice @@ -50,7 +51,7 @@ data OpContext = | OpRecvMessageContext (Ptr C.ByteBuffer) | OpRecvStatusOnClientContext (Ptr C.MetadataArray) (Ptr C.StatusCode) Slice | OpRecvCloseOnServerContext (Ptr CInt) - deriving Show + deriving (Show) -- | Length we pass to gRPC for receiving status details -- when processing 'OpRecvStatusOnClient'. It appears that gRPC actually ignores @@ -67,9 +68,9 @@ createOpContext (OpSendMessage bs) = createOpContext (OpSendCloseFromClient) = return OpSendCloseFromClientContext createOpContext (OpSendStatusFromServer m code (StatusDetails str)) = uncurry OpSendStatusFromServerContext - <$> C.createMetadata m - <*> return code - <*> C.byteStringToSlice str + <$> C.createMetadata m + <*> return code + <*> C.byteStringToSlice str createOpContext OpRecvInitialMetadata = fmap OpRecvInitialMetadataContext C.metadataArrayCreate createOpContext OpRecvMessage = @@ -87,12 +88,12 @@ createOpContext OpRecvCloseOnServer = setOpArray :: C.OpArray -> Int -> OpContext -> IO () setOpArray arr i (OpSendInitialMetadataContext kvs l) = C.opSendInitialMetadata arr i kvs l -setOpArray arr i (OpSendMessageContext (bb,_)) = +setOpArray arr i (OpSendMessageContext (bb, _)) = C.opSendMessage arr i bb setOpArray arr i OpSendCloseFromClientContext = C.opSendCloseClient arr i setOpArray arr i (OpSendStatusFromServerContext kvs l code details) = - C.opSendStatusServer arr i l kvs code details + C.opSendStatusServer arr i l kvs code details setOpArray arr i (OpRecvInitialMetadataContext pmetadata) = C.opRecvInitialMetadata arr i pmetadata setOpArray arr i (OpRecvMessageContext pbb) = @@ -120,29 +121,33 @@ freeOpContext (OpRecvStatusOnClientContext metadata pcode slice) = do C.freeSlice slice freeOpContext (OpRecvCloseOnServerContext pcancelled) = grpcDebug ("freeOpContext: freeing pcancelled: " ++ show pcancelled) - >> free pcancelled + >> free pcancelled -- | Allocates an `OpArray` and a list of `OpContext`s from the given list of -- `Op`s. withOpArrayAndCtxts :: [Op] -> ((C.OpArray, [OpContext]) -> IO a) -> IO a withOpArrayAndCtxts ops = bracket setup teardown - where setup = do ctxts <- mapM createOpContext ops - let l = length ops - arr <- C.opArrayCreate l - sequence_ $ zipWith (setOpArray arr) [0..l-1] ctxts - return (arr, ctxts) - teardown (arr, ctxts) = do C.opArrayDestroy arr (length ctxts) - mapM_ freeOpContext ctxts + where + setup = do + ctxts <- mapM createOpContext ops + let l = length ops + arr <- C.opArrayCreate l + sequence_ $ zipWith (setOpArray arr) [0 .. l - 1] ctxts + return (arr, ctxts) + teardown (arr, ctxts) = do + C.opArrayDestroy arr (length ctxts) + mapM_ freeOpContext ctxts -- | Container holding GC-managed results for 'Op's which receive data. -data OpRecvResult = - OpRecvInitialMetadataResult MetadataMap - | OpRecvMessageResult (Maybe B.ByteString) - -- ^ If a streaming call is in progress and the stream terminates normally, +data OpRecvResult + = OpRecvInitialMetadataResult MetadataMap + | -- | If a streaming call is in progress and the stream terminates normally, -- or If the client or server dies, we might not receive a response body, in -- which case this will be 'Nothing'. + OpRecvMessageResult (Maybe B.ByteString) | OpRecvStatusOnClientResult MetadataMap C.StatusCode B.ByteString - | OpRecvCloseOnServerResult Bool -- ^ True if call was cancelled. + | -- | True if call was cancelled. + OpRecvCloseOnServerResult Bool deriving (Show) -- | For the given 'OpContext', if the 'Op' receives data, copies the data out @@ -158,11 +163,13 @@ resultFromOpContext (OpRecvMessageContext pbb) = do grpcDebug "resultFromOpContext: OpRecvMessageContext" bb@(C.ByteBuffer bbptr) <- peek pbb if bbptr == nullPtr - then do grpcDebug "resultFromOpContext: WARNING: got empty message." - return $ Just $ OpRecvMessageResult Nothing - else do bs <- C.copyByteBufferToByteString bb - grpcDebug $ "resultFromOpContext: bb copied: " ++ show bs - return $ Just $ OpRecvMessageResult (Just bs) + then do + grpcDebug "resultFromOpContext: WARNING: got empty message." + return $ Just $ OpRecvMessageResult Nothing + else do + bs <- C.copyByteBufferToByteString bb + grpcDebug $ "resultFromOpContext: bb copied: " ++ show bs + return $ Just $ OpRecvMessageResult (Just bs) resultFromOpContext (OpRecvStatusOnClientContext pmetadata pcode pstr) = do grpcDebug "resultFromOpContext: OpRecvStatusOnClientContext" metadata <- peek pmetadata @@ -172,8 +179,10 @@ resultFromOpContext (OpRecvStatusOnClientContext pmetadata pcode pstr) = do return $ Just $ OpRecvStatusOnClientResult metadataMap code statusInfo resultFromOpContext (OpRecvCloseOnServerContext pcancelled) = do grpcDebug "resultFromOpContext: OpRecvCloseOnServerContext" - cancelled <- fmap (\x -> if x > 0 then True else False) - (peek pcancelled) + cancelled <- + fmap + (\x -> if x > 0 then True else False) + (peek pcancelled) return $ Just $ OpRecvCloseOnServerResult cancelled resultFromOpContext _ = do grpcDebug "resultFromOpContext: saw non-result op type." @@ -197,70 +206,74 @@ resultFromOpContext _ = do -- GRPC_CALL_ERROR_TOO_MANY_OPERATIONS error if we use the same 'Op' twice in -- the same batch, so we might want to change the list to a set. I don't think -- order matters within a batch. Need to check. -runOps :: C.Call - -- ^ 'Call' that this batch is associated with. One call can be - -- associated with many batches. - -> CompletionQueue - -- ^ Queue on which our tag will be placed once our ops are done - -- running. - -> [Op] - -- ^ The list of 'Op's to execute. - -> IO (Either GRPCIOError [OpRecvResult]) +runOps :: + -- | 'Call' that this batch is associated with. One call can be + -- associated with many batches. + C.Call -> + -- | Queue on which our tag will be placed once our ops are done + -- running. + CompletionQueue -> + -- | The list of 'Op's to execute. + [Op] -> + IO (Either GRPCIOError [OpRecvResult]) runOps call cq ops = - let l = length ops in - -- It is crucial to mask exceptions here. If we don’t do this, we can - -- run into the following situation: - -- - -- 1. We allocate an OpContext, e.g., OpRecvMessageContext and the corresponding ByteBuffer. - -- 2. We pass the buffer to gRPC in startBatch. - -- 3. If we now get an exception we will free the ByteBuffer. - -- 4. gRPC can now end up writing to the freed ByteBuffer and we get a heap corruption. - withOpArrayAndCtxts ops $ \(opArray, contexts) -> mask_ $ do - grpcDebug $ "runOps: allocated op contexts: " ++ show contexts - tag <- newTag cq - grpcDebug $ "runOps: tag: " ++ show tag - callError <- startBatch cq call opArray l tag - grpcDebug $ "runOps: called start_batch. callError: " - ++ (show callError) - case callError of - Left x -> return $ Left x - Right () -> do - ev <- pluck cq tag Nothing - grpcDebug $ "runOps: pluck returned " ++ show ev - case ev of - Right () -> do - grpcDebug "runOps: got good op; starting." - fmap (Right . catMaybes) $ mapM resultFromOpContext contexts - Left err -> return $ Left err + let l = length ops + in -- It is crucial to mask exceptions here. If we don’t do this, we can + -- run into the following situation: + -- + -- 1. We allocate an OpContext, e.g., OpRecvMessageContext and the corresponding ByteBuffer. + -- 2. We pass the buffer to gRPC in startBatch. + -- 3. If we now get an exception we will free the ByteBuffer. + -- 4. gRPC can now end up writing to the freed ByteBuffer and we get a heap corruption. + withOpArrayAndCtxts ops $ \(opArray, contexts) -> mask_ $ do + grpcDebug $ "runOps: allocated op contexts: " ++ show contexts + tag <- newTag cq + grpcDebug $ "runOps: tag: " ++ show tag + callError <- startBatch cq call opArray l tag + grpcDebug $ + "runOps: called start_batch. callError: " + ++ (show callError) + case callError of + Left x -> return $ Left x + Right () -> do + ev <- pluck cq tag Nothing + grpcDebug $ "runOps: pluck returned " ++ show ev + case ev of + Right () -> do + grpcDebug "runOps: got good op; starting." + fmap (Right . catMaybes) $ mapM resultFromOpContext contexts + Left err -> return $ Left err -runOps' :: C.Call - -> CompletionQueue - -> [Op] - -> ExceptT GRPCIOError IO [OpRecvResult] +runOps' :: + C.Call -> + CompletionQueue -> + [Op] -> + ExceptT GRPCIOError IO [OpRecvResult] runOps' c cq = ExceptT . runOps c cq -- | If response status info is present in the given 'OpRecvResult's, returns -- a tuple of trailing metadata, status code, and status details. -extractStatusInfo :: [OpRecvResult] - -> Maybe (MetadataMap, C.StatusCode, B.ByteString) +extractStatusInfo :: + [OpRecvResult] -> + Maybe (MetadataMap, C.StatusCode, B.ByteString) extractStatusInfo [] = Nothing -extractStatusInfo (OpRecvStatusOnClientResult meta code details:_) = +extractStatusInfo (OpRecvStatusOnClientResult meta code details : _) = Just (meta, code, details) -extractStatusInfo (_:xs) = extractStatusInfo xs +extractStatusInfo (_ : xs) = extractStatusInfo xs -------------------------------------------------------------------------------- -- Types and helpers for common ops batches -type SendSingle a - = C.Call - -> CompletionQueue - -> a - -> ExceptT GRPCIOError IO () +type SendSingle a = + C.Call -> + CompletionQueue -> + a -> + ExceptT GRPCIOError IO () -type RecvSingle a - = C.Call - -> CompletionQueue - -> ExceptT GRPCIOError IO a +type RecvSingle a = + C.Call -> + CompletionQueue -> + ExceptT GRPCIOError IO a pattern RecvMsgRslt :: Maybe ByteString -> Either a [OpRecvResult] pattern RecvMsgRslt mmsg <- Right [OpRecvMessageResult mmsg] @@ -276,28 +289,31 @@ sendStatusFromServer c cq (md, st, ds) = sendSingle c cq (OpSendStatusFromServer md st ds) recvInitialMessage :: RecvSingle ByteString -recvInitialMessage c cq = ExceptT (streamRecvPrim c cq ) >>= \case - Nothing -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMessage: no message.") - Just bs -> return bs +recvInitialMessage c cq = + ExceptT (streamRecvPrim c cq) >>= \case + Nothing -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMessage: no message.") + Just bs -> return bs recvInitialMetadata :: RecvSingle MetadataMap -recvInitialMetadata c cq = runOps' c cq [OpRecvInitialMetadata] >>= \case - [OpRecvInitialMetadataResult md] - -> return md - _ -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMetadata") +recvInitialMetadata c cq = + runOps' c cq [OpRecvInitialMetadata] >>= \case + [OpRecvInitialMetadataResult md] -> + return md + _ -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMetadata") recvInitialMsgMD :: RecvSingle (Maybe ByteString, MetadataMap) -recvInitialMsgMD c cq = runOps' c cq [OpRecvInitialMetadata, OpRecvMessage] >>= \case - [ OpRecvInitialMetadataResult md, OpRecvMessageResult mmsg] - -> return (mmsg, md) - _ -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMsgMD") +recvInitialMsgMD c cq = + runOps' c cq [OpRecvInitialMetadata, OpRecvMessage] >>= \case + [OpRecvInitialMetadataResult md, OpRecvMessageResult mmsg] -> + return (mmsg, md) + _ -> throwE (GRPCIOInternalUnexpectedRecv "recvInitialMsgMD") recvStatusOnClient :: RecvSingle (MetadataMap, C.StatusCode, StatusDetails) -recvStatusOnClient c cq = runOps' c cq [OpRecvStatusOnClient] >>= \case - [OpRecvStatusOnClientResult md st ds] - -> return (md, st, StatusDetails ds) - _ -> throwE (GRPCIOInternalUnexpectedRecv "recvStatusOnClient") - +recvStatusOnClient c cq = + runOps' c cq [OpRecvStatusOnClient] >>= \case + [OpRecvStatusOnClientResult md st ds] -> + return (md, st, StatusDetails ds) + _ -> throwE (GRPCIOInternalUnexpectedRecv "recvStatusOnClient") -------------------------------------------------------------------------------- -- Streaming types and helpers @@ -307,21 +323,21 @@ streamRecvPrim :: C.Call -> CompletionQueue -> StreamRecv ByteString streamRecvPrim c cq = f <$> runOps c cq [OpRecvMessage] where f (RecvMsgRslt mmsg) = Right mmsg - f Right{} = Left (GRPCIOInternalUnexpectedRecv "streamRecvPrim") - f (Left e) = Left e + f Right{} = Left (GRPCIOInternalUnexpectedRecv "streamRecvPrim") + f (Left e) = Left e type StreamSend a = a -> IO (Either GRPCIOError ()) streamSendPrim :: C.Call -> CompletionQueue -> StreamSend ByteString streamSendPrim c cq bs = f <$> runOps c cq [OpSendMessage bs] where f (Right []) = Right () - f Right{} = Left (GRPCIOInternalUnexpectedRecv "streamSendPrim") - f (Left e) = Left e + f Right{} = Left (GRPCIOInternalUnexpectedRecv "streamSendPrim") + f (Left e) = Left e type WritesDone = IO (Either GRPCIOError ()) writesDonePrim :: C.Call -> CompletionQueue -> WritesDone writesDonePrim c cq = f <$> runOps c cq [OpSendCloseFromClient] where f (Right []) = Right () - f Right{} = Left (GRPCIOInternalUnexpectedRecv "writesDonePrim") - f (Left e) = Left e + f Right{} = Left (GRPCIOInternalUnexpectedRecv "writesDonePrim") + f (Left e) = Left e diff --git a/core/src/Network/GRPC/LowLevel/Server.hs b/core/src/Network/GRPC/LowLevel/Server.hs index bb0290c6..93036d20 100644 --- a/core/src/Network/GRPC/LowLevel/Server.hs +++ b/core/src/Network/GRPC/LowLevel/Server.hs @@ -1,75 +1,85 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE FlexibleContexts #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE TupleSections #-} -{-# LANGUAGE ViewPatterns #-} -{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE CPP #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE ViewPatterns #-} -- | This module defines data structures and operations pertaining to registered -- servers using registered calls; for unregistered support, see -- `Network.GRPC.LowLevel.Server.Unregistered`. module Network.GRPC.LowLevel.Server where -import Control.Concurrent (ThreadId - , forkFinally - , myThreadId - , killThread) -import Control.Concurrent.STM (atomically - , check) -import Control.Concurrent.STM.TVar (TVar - , modifyTVar' - , readTVar - , writeTVar - , readTVarIO - , newTVarIO) -import Control.Exception (bracket) -import Control.Monad -import Control.Monad.IO.Class -import Control.Monad.Trans.Except -import Data.ByteString (ByteString) -import qualified Data.ByteString as B +import Control.Concurrent ( + ThreadId, + forkFinally, + killThread, + myThreadId, + ) +import Control.Concurrent.STM ( + atomically, + check, + ) +import Control.Concurrent.STM.TVar ( + TVar, + modifyTVar', + newTVarIO, + readTVar, + readTVarIO, + writeTVar, + ) +import Control.Exception (bracket) +import Control.Monad +import Control.Monad.IO.Class +import Control.Monad.Trans.Except +import Data.ByteString (ByteString) +import qualified Data.ByteString as B import qualified Data.Set as S -import Network.GRPC.LowLevel.Call -import Network.GRPC.LowLevel.CompletionQueue (CompletionQueue, - createCompletionQueue, - pluck, - serverRegisterCompletionQueue, - serverRequestCall, - serverShutdownAndNotify, - shutdownCompletionQueue) -import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op -import qualified Network.GRPC.Unsafe as C -import qualified Network.GRPC.Unsafe.ChannelArgs as C -import qualified Network.GRPC.Unsafe.Op as C -import qualified Network.GRPC.Unsafe.Security as C +import Network.GRPC.LowLevel.Call +import Network.GRPC.LowLevel.CompletionQueue ( + CompletionQueue, + createCompletionQueue, + pluck, + serverRegisterCompletionQueue, + serverRequestCall, + serverShutdownAndNotify, + shutdownCompletionQueue, + ) +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Op +import qualified Network.GRPC.Unsafe as C +import qualified Network.GRPC.Unsafe.ChannelArgs as C +import qualified Network.GRPC.Unsafe.Op as C +import qualified Network.GRPC.Unsafe.Security as C -- | Wraps various gRPC state needed to run a server. data Server = Server - { serverGRPC :: GRPC - , unsafeServer :: C.Server - , listeningPort :: Port - , serverCQ :: CompletionQueue + { serverGRPC :: GRPC + , unsafeServer :: C.Server + , listeningPort :: Port + , serverCQ :: CompletionQueue -- ^ CQ used for receiving new calls. - , serverCallCQ :: CompletionQueue + , serverCallCQ :: CompletionQueue -- ^ CQ for running ops on calls. Not used to receive new calls. - , normalMethods :: [RegisteredMethod 'Normal] - , sstreamingMethods :: [RegisteredMethod 'ServerStreaming] - , cstreamingMethods :: [RegisteredMethod 'ClientStreaming] + , normalMethods :: [RegisteredMethod 'Normal] + , sstreamingMethods :: [RegisteredMethod 'ServerStreaming] + , cstreamingMethods :: [RegisteredMethod 'ClientStreaming] , bidiStreamingMethods :: [RegisteredMethod 'BiDiStreaming] - , serverConfig :: ServerConfig - , outstandingForks :: TVar (S.Set ThreadId) - , serverShuttingDown :: TVar Bool + , serverConfig :: ServerConfig + , outstandingForks :: TVar (S.Set ThreadId) + , serverShuttingDown :: TVar Bool } -- TODO: should we make a forkGRPC function instead? I am not sure if it would -- be safe to let the call handlers threads keep running after the server stops, -- so I'm taking the more conservative route of ensuring the server will -- stay alive. Experiment more when time permits. + -- | Fork a thread from the server (presumably for a handler) with the guarantee -- that the server won't shut down while the thread is alive. -- Returns true if the fork happens successfully, and false if the server is @@ -83,7 +93,7 @@ forkServer :: Server -> IO () -> IO Bool forkServer Server{..} f = do shutdown <- readTVarIO serverShuttingDown case shutdown of - True -> return False + True -> return False False -> do -- NB: The spawned thread waits on 'ready' before running 'f' to ensure -- that its ThreadId is inserted into outstandingForks before the cleanup @@ -92,44 +102,55 @@ forkServer Server{..} f = do -- subsequent deadlock in stopServer. We can use a dead-list instead if we -- need something more performant. ready <- newTVarIO False - tid <- let act = do atomically (check =<< readTVar ready) - f - in forkFinally act cleanup + tid <- forkFinally (atomically (check =<< readTVar ready) >> f) cleanup + atomically $ do modifyTVar' outstandingForks (S.insert tid) modifyTVar' ready (const True) + + debug + + return True + where + cleanup _ = do + tid <- myThreadId + atomically $ modifyTVar' outstandingForks (S.delete tid) + #ifdef DEBUG + -- This is intentionally moved outside of the 'do' block so that this file + -- can be auto-formatted by fourmolu. Fourmolu's support for formatting CPP + -- fails inside of 'do' blocks. + debug = do tids <- readTVarIO outstandingForks - grpcDebug $ "after fork and bookkeeping: outstandingForks=" ++ show tids + grpcDebug ("after fork and bookkeeping: outstandingForks=" ++ show tids) +# else + debug = pure () #endif - return True - where cleanup _ = do - tid <- myThreadId - atomically $ modifyTVar' outstandingForks (S.delete tid) -- | Configuration for SSL. data ServerSSLConfig = ServerSSLConfig - {clientRootCert :: Maybe FilePath, - serverPrivateKey :: FilePath, - serverCert :: FilePath, - clientCertRequest :: C.SslClientCertificateRequestType, - -- ^ Whether to request a certificate from the client, and what to do with it - -- if received. - customMetadataProcessor :: Maybe C.ProcessMeta} + { clientRootCert :: Maybe FilePath + , serverPrivateKey :: FilePath + , serverCert :: FilePath + , clientCertRequest :: C.SslClientCertificateRequestType + -- ^ Whether to request a certificate from the client, and what to do with it + -- if received. + , customMetadataProcessor :: Maybe C.ProcessMeta + } -- | Configuration needed to start a server. data ServerConfig = ServerConfig - { host :: Host - -- ^ Name of the host the server is running on. Not sure how this is - -- used. Setting to "localhost" works fine in tests. - , port :: Port - -- ^ Port on which to listen for requests. + { host :: Host + -- ^ Name of the host the server is running on. Not sure how this is + -- used. Setting to "localhost" works fine in tests. + , port :: Port + -- ^ Port on which to listen for requests. , methodsToRegisterNormal :: [MethodName] - -- ^ List of normal (non-streaming) methods to register. + -- ^ List of normal (non-streaming) methods to register. , methodsToRegisterClientStreaming :: [MethodName] , methodsToRegisterServerStreaming :: [MethodName] , methodsToRegisterBiDiStreaming :: [MethodName] - , serverArgs :: [C.Arg] + , serverArgs :: [C.Arg] -- ^ Optional arguments for setting up the channel on the server. Supplying an -- empty list will cause the channel to use gRPC's default options. , sslConfig :: Maybe ServerSSLConfig @@ -145,15 +166,17 @@ addPort server conf@ServerConfig{..} = case sslConfig of Nothing -> C.withInsecureServerCredentials $ C.grpcServerAddHttp2Port server e Just ServerSSLConfig{..} -> - do crc <- mapM B.readFile clientRootCert - spk <- B.readFile serverPrivateKey - sc <- B.readFile serverCert - C.withServerCredentials crc spk sc clientCertRequest $ \creds -> do - case customMetadataProcessor of - Just p -> C.setMetadataProcessor creds p - Nothing -> return () - C.grpcServerAddHttp2Port server e creds - where e = unEndpoint $ serverEndpoint conf + do + crc <- mapM B.readFile clientRootCert + spk <- B.readFile serverPrivateKey + sc <- B.readFile serverCert + C.withServerCredentials crc spk sc clientCertRequest $ \creds -> do + case customMetadataProcessor of + Just p -> C.setMetadataProcessor creds p + Nothing -> return () + C.grpcServerAddHttp2Port server e creds + where + e = unEndpoint $ serverEndpoint conf startServer :: GRPC -> ServerConfig -> IO Server startServer grpc conf@ServerConfig{..} = @@ -162,7 +185,8 @@ startServer grpc conf@ServerConfig{..} = server <- C.grpcServerCreate args C.reserved actualPort <- addPort server conf when (unPort port > 0 && actualPort /= unPort port) $ - error $ "Unable to bind port: " ++ show port + error $ + "Unable to bind port: " ++ show port cq <- createCompletionQueue grpc grpcDebug $ "startServer: server CQ: " ++ show cq serverRegisterCompletionQueue server cq @@ -171,26 +195,44 @@ startServer grpc conf@ServerConfig{..} = -- to partition them this way, but we get very convenient phantom typing -- elsewhere by doing so. -- TODO: change order of args so we can eta reduce. - ns <- mapM (\nm -> serverRegisterMethodNormal server nm e) - methodsToRegisterNormal - ss <- mapM (\nm -> serverRegisterMethodServerStreaming server nm e) - methodsToRegisterServerStreaming - cs <- mapM (\nm -> serverRegisterMethodClientStreaming server nm e) - methodsToRegisterClientStreaming - bs <- mapM (\nm -> serverRegisterMethodBiDiStreaming server nm e) - methodsToRegisterBiDiStreaming + ns <- + mapM + (\nm -> serverRegisterMethodNormal server nm e) + methodsToRegisterNormal + ss <- + mapM + (\nm -> serverRegisterMethodServerStreaming server nm e) + methodsToRegisterServerStreaming + cs <- + mapM + (\nm -> serverRegisterMethodClientStreaming server nm e) + methodsToRegisterClientStreaming + bs <- + mapM + (\nm -> serverRegisterMethodBiDiStreaming server nm e) + methodsToRegisterBiDiStreaming C.grpcServerStart server forks <- newTVarIO S.empty shutdown <- newTVarIO False ccq <- createCompletionQueue grpc - return $ Server grpc server (Port actualPort) cq ccq ns ss cs bs conf forks - shutdown - - + return $ + Server + grpc + server + (Port actualPort) + cq + ccq + ns + ss + cs + bs + conf + forks + shutdown stopServer :: Server -> IO () -- TODO: Do method handles need to be freed? -stopServer Server{ unsafeServer = s, .. } = do +stopServer Server{unsafeServer = s, ..} = do grpcDebug "stopServer: calling shutdownNotify." shutdownNotify serverCQ grpcDebug "stopServer: cancelling all calls." @@ -201,34 +243,35 @@ stopServer Server{ unsafeServer = s, .. } = do grpcDebug "stopServer: shutting down CQ." shutdownCQ serverCQ shutdownCQ serverCallCQ - - where shutdownCQ scq = do - shutdownResult <- shutdownCompletionQueue scq - case shutdownResult of - Left _ -> do putStrLn "Warning: completion queue didn't shut down." - putStrLn "Trying to stop server anyway." - Right _ -> return () - shutdownNotify scq = do - let shutdownTag = C.tag 0 - serverShutdownAndNotify s scq shutdownTag - grpcDebug "called serverShutdownAndNotify; plucking." - shutdownEvent <- pluck scq shutdownTag (Just 30) - grpcDebug $ "shutdownNotify: got shutdown event" ++ show shutdownEvent - case shutdownEvent of - -- This case occurs when we pluck but the queue is already in the - -- 'shuttingDown' state, implying we already tried to shut down. - Left GRPCIOShutdown -> error "Called stopServer twice!" - Left _ -> error "Failed to stop server." - Right _ -> return () - cleanupForks = do - atomically $ writeTVar serverShuttingDown True - liveForks <- readTVarIO outstandingForks - grpcDebug $ "Server shutdown: killing threads: " ++ show liveForks - mapM_ killThread liveForks - -- wait for threads to shut down - grpcDebug "Server shutdown: waiting until all threads are dead." - atomically $ check . (==0) . S.size =<< readTVar outstandingForks - grpcDebug "Server shutdown: All forks cleaned up." + where + shutdownCQ scq = do + shutdownResult <- shutdownCompletionQueue scq + case shutdownResult of + Left _ -> do + putStrLn "Warning: completion queue didn't shut down." + putStrLn "Trying to stop server anyway." + Right _ -> return () + shutdownNotify scq = do + let shutdownTag = C.tag 0 + serverShutdownAndNotify s scq shutdownTag + grpcDebug "called serverShutdownAndNotify; plucking." + shutdownEvent <- pluck scq shutdownTag (Just 30) + grpcDebug $ "shutdownNotify: got shutdown event" ++ show shutdownEvent + case shutdownEvent of + -- This case occurs when we pluck but the queue is already in the + -- 'shuttingDown' state, implying we already tried to shut down. + Left GRPCIOShutdown -> error "Called stopServer twice!" + Left _ -> error "Failed to stop server." + Right _ -> return () + cleanupForks = do + atomically $ writeTVar serverShuttingDown True + liveForks <- readTVarIO outstandingForks + grpcDebug $ "Server shutdown: killing threads: " ++ show liveForks + mapM_ killThread liveForks + -- wait for threads to shut down + grpcDebug "Server shutdown: waiting until all threads are dead." + atomically $ check . (== 0) . S.size =<< readTVar outstandingForks + grpcDebug "Server shutdown: All forks cleaned up." -- Uses 'bracket' to safely start and stop a server, even if exceptions occur. withServer :: GRPC -> ServerConfig -> (Server -> IO a) -> IO a @@ -238,16 +281,18 @@ withServer grpc cfg = bracket (startServer grpc cfg) stopServer -- 'serverRegisterMethodNormal', 'serverRegisterMethodServerStreaming', -- 'serverRegisterMethodClientStreaming', and -- 'serverRegisterMethodBiDiStreaming'. -serverRegisterMethod :: C.Server - -> MethodName - -> Endpoint - -> GRPCMethodType - -> IO C.CallHandle +serverRegisterMethod :: + C.Server -> + MethodName -> + Endpoint -> + GRPCMethodType -> + IO C.CallHandle serverRegisterMethod s nm e mty = - C.grpcServerRegisterMethod s - (unMethodName nm) - (unEndpoint e) - (payloadHandling mty) + C.grpcServerRegisterMethod + s + (unMethodName nm) + (unEndpoint e) + (payloadHandling mty) {- TODO: Consolidate the register functions below. @@ -269,82 +314,83 @@ constructor t the function was given. -- to wait for a request to arrive. Note: gRPC claims this must be called before -- the server is started, so we do it during startup according to the -- 'ServerConfig'. -serverRegisterMethodNormal :: C.Server - -> MethodName - -- ^ method name, e.g. "/foo" - -> Endpoint - -- ^ Endpoint name name, e.g. "localhost:9999". I have no - -- idea why this is needed since we have to provide these - -- parameters to start a server in the first place. It - -- doesn't seem to have any effect, even if it's filled - -- with nonsense. - -> IO (RegisteredMethod 'Normal) +serverRegisterMethodNormal :: + C.Server -> + -- | method name, e.g. "/foo" + MethodName -> + -- | Endpoint name name, e.g. "localhost:9999". I have no + -- idea why this is needed since we have to provide these + -- parameters to start a server in the first place. It + -- doesn't seem to have any effect, even if it's filled + -- with nonsense. + Endpoint -> + IO (RegisteredMethod 'Normal) serverRegisterMethodNormal internalServer meth e = do h <- serverRegisterMethod internalServer meth e Normal return $ RegisteredMethodNormal meth e h -serverRegisterMethodClientStreaming - :: C.Server - -> MethodName - -- ^ method name, e.g. "/foo" - -> Endpoint - -- ^ Endpoint name name, e.g. "localhost:9999". I have no - -- idea why this is needed since we have to provide these - -- parameters to start a server in the first place. It - -- doesn't seem to have any effect, even if it's filled - -- with nonsense. - -> IO (RegisteredMethod 'ClientStreaming) +serverRegisterMethodClientStreaming :: + C.Server -> + -- | method name, e.g. "/foo" + MethodName -> + -- | Endpoint name name, e.g. "localhost:9999". I have no + -- idea why this is needed since we have to provide these + -- parameters to start a server in the first place. It + -- doesn't seem to have any effect, even if it's filled + -- with nonsense. + Endpoint -> + IO (RegisteredMethod 'ClientStreaming) serverRegisterMethodClientStreaming internalServer meth e = do h <- serverRegisterMethod internalServer meth e ClientStreaming return $ RegisteredMethodClientStreaming meth e h - -serverRegisterMethodServerStreaming - :: C.Server - -> MethodName - -- ^ method name, e.g. "/foo" - -> Endpoint - -- ^ Endpoint name name, e.g. "localhost:9999". I have no - -- idea why this is needed since we have to provide these - -- parameters to start a server in the first place. It - -- doesn't seem to have any effect, even if it's filled - -- with nonsense. - -> IO (RegisteredMethod 'ServerStreaming) +serverRegisterMethodServerStreaming :: + C.Server -> + -- | method name, e.g. "/foo" + MethodName -> + -- | Endpoint name name, e.g. "localhost:9999". I have no + -- idea why this is needed since we have to provide these + -- parameters to start a server in the first place. It + -- doesn't seem to have any effect, even if it's filled + -- with nonsense. + Endpoint -> + IO (RegisteredMethod 'ServerStreaming) serverRegisterMethodServerStreaming internalServer meth e = do h <- serverRegisterMethod internalServer meth e ServerStreaming return $ RegisteredMethodServerStreaming meth e h - -serverRegisterMethodBiDiStreaming - :: C.Server - -> MethodName - -- ^ method name, e.g. "/foo" - -> Endpoint - -- ^ Endpoint name name, e.g. "localhost:9999". I have no - -- idea why this is needed since we have to provide these - -- parameters to start a server in the first place. It - -- doesn't seem to have any effect, even if it's filled - -- with nonsense. - -> IO (RegisteredMethod 'BiDiStreaming) +serverRegisterMethodBiDiStreaming :: + C.Server -> + -- | method name, e.g. "/foo" + MethodName -> + -- | Endpoint name name, e.g. "localhost:9999". I have no + -- idea why this is needed since we have to provide these + -- parameters to start a server in the first place. It + -- doesn't seem to have any effect, even if it's filled + -- with nonsense. + Endpoint -> + IO (RegisteredMethod 'BiDiStreaming) serverRegisterMethodBiDiStreaming internalServer meth e = do h <- serverRegisterMethod internalServer meth e BiDiStreaming return $ RegisteredMethodBiDiStreaming meth e h -- | Create a 'Call' with which to wait for the invocation of a registered -- method. -serverCreateCall :: Server - -> RegisteredMethod mt - -> IO (Either GRPCIOError (ServerCall (MethodPayload mt))) +serverCreateCall :: + Server -> + RegisteredMethod mt -> + IO (Either GRPCIOError (ServerCall (MethodPayload mt))) serverCreateCall Server{..} rm = serverRequestCall rm unsafeServer serverCQ serverCallCQ -withServerCall :: Server - -> RegisteredMethod mt - -> (ServerCall (MethodPayload mt) -> IO (Either GRPCIOError a)) - -> IO (Either GRPCIOError a) +withServerCall :: + Server -> + RegisteredMethod mt -> + (ServerCall (MethodPayload mt) -> IO (Either GRPCIOError a)) -> + IO (Either GRPCIOError a) withServerCall s rm f = bracket (serverCreateCall s rm) cleanup $ \case - Left e -> return (Left e) + Left e -> return (Left e) Right c -> do debugServerCall c f c @@ -357,55 +403,66 @@ withServerCall s rm f = -------------------------------------------------------------------------------- -- serverReader (server side of client streaming mode) -type ServerReaderHandlerLL - = ServerCall (MethodPayload 'ClientStreaming) - -> StreamRecv ByteString - -> IO (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails) - -serverReader :: Server - -> RegisteredMethod 'ClientStreaming - -> MetadataMap -- ^ Initial server metadata - -> ServerReaderHandlerLL - -> IO (Either GRPCIOError ()) +type ServerReaderHandlerLL = + ServerCall (MethodPayload 'ClientStreaming) -> + StreamRecv ByteString -> + IO (Maybe ByteString, MetadataMap, C.StatusCode, StatusDetails) + +serverReader :: + Server -> + RegisteredMethod 'ClientStreaming -> + -- | Initial server metadata + MetadataMap -> + ServerReaderHandlerLL -> + IO (Either GRPCIOError ()) serverReader s rm initMeta f = withServerCall s rm (\sc -> serverReader' s sc initMeta f) -serverReader' :: Server - -> ServerCall (MethodPayload 'ClientStreaming) - -> MetadataMap -- ^ Initial server metadata - -> ServerReaderHandlerLL - -> IO (Either GRPCIOError ()) -serverReader' _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = +serverReader' :: + Server -> + ServerCall (MethodPayload 'ClientStreaming) -> + -- | Initial server metadata + MetadataMap -> + ServerReaderHandlerLL -> + IO (Either GRPCIOError ()) +serverReader' _ sc@ServerCall{unsafeSC = c, callCQ = ccq} initMeta f = runExceptT $ do (mmsg, trailMeta, st, ds) <- liftIO $ f sc (streamRecvPrim c ccq) - void $ runOps' c ccq ( OpSendInitialMetadata initMeta - : OpSendStatusFromServer trailMeta st ds - : maybe [] ((:[]) . OpSendMessage) mmsg - ) + void $ + runOps' + c + ccq + ( OpSendInitialMetadata initMeta + : OpSendStatusFromServer trailMeta st ds + : maybe [] ((: []) . OpSendMessage) mmsg + ) -------------------------------------------------------------------------------- -- serverWriter (server side of server streaming mode) -type ServerWriterHandlerLL - = ServerCall (MethodPayload 'ServerStreaming) - -> StreamSend ByteString - -> IO (MetadataMap, C.StatusCode, StatusDetails) +type ServerWriterHandlerLL = + ServerCall (MethodPayload 'ServerStreaming) -> + StreamSend ByteString -> + IO (MetadataMap, C.StatusCode, StatusDetails) -- | Wait for and then handle a registered, server-streaming call. -serverWriter :: Server - -> RegisteredMethod 'ServerStreaming - -> MetadataMap -- ^ Initial server metadata - -> ServerWriterHandlerLL - -> IO (Either GRPCIOError ()) +serverWriter :: + Server -> + RegisteredMethod 'ServerStreaming -> + -- | Initial server metadata + MetadataMap -> + ServerWriterHandlerLL -> + IO (Either GRPCIOError ()) serverWriter s rm initMeta f = withServerCall s rm (\sc -> serverWriter' s sc initMeta f) -serverWriter' :: Server - -> ServerCall (MethodPayload 'ServerStreaming) - -> MetadataMap - -> ServerWriterHandlerLL - -> IO (Either GRPCIOError ()) -serverWriter' _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = +serverWriter' :: + Server -> + ServerCall (MethodPayload 'ServerStreaming) -> + MetadataMap -> + ServerWriterHandlerLL -> + IO (Either GRPCIOError ()) +serverWriter' _ sc@ServerCall{unsafeSC = c, callCQ = ccq} initMeta f = runExceptT $ do sendInitialMetadata c ccq initMeta st <- liftIO $ f sc (streamSendPrim c ccq) @@ -414,26 +471,29 @@ serverWriter' _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = -------------------------------------------------------------------------------- -- serverRW (bidirectional streaming mode) -type ServerRWHandlerLL - = ServerCall (MethodPayload 'BiDiStreaming) - -> StreamRecv ByteString - -> StreamSend ByteString - -> IO (MetadataMap, C.StatusCode, StatusDetails) - -serverRW :: Server - -> RegisteredMethod 'BiDiStreaming - -> MetadataMap -- ^ initial server metadata - -> ServerRWHandlerLL - -> IO (Either GRPCIOError ()) +type ServerRWHandlerLL = + ServerCall (MethodPayload 'BiDiStreaming) -> + StreamRecv ByteString -> + StreamSend ByteString -> + IO (MetadataMap, C.StatusCode, StatusDetails) + +serverRW :: + Server -> + RegisteredMethod 'BiDiStreaming -> + -- | initial server metadata + MetadataMap -> + ServerRWHandlerLL -> + IO (Either GRPCIOError ()) serverRW s rm initMeta f = withServerCall s rm (\sc -> serverRW' s sc initMeta f) -serverRW' :: Server - -> ServerCall (MethodPayload 'BiDiStreaming) - -> MetadataMap - -> ServerRWHandlerLL - -> IO (Either GRPCIOError ()) -serverRW' _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = +serverRW' :: + Server -> + ServerCall (MethodPayload 'BiDiStreaming) -> + MetadataMap -> + ServerRWHandlerLL -> + IO (Either GRPCIOError ()) +serverRW' _ sc@ServerCall{unsafeSC = c, callCQ = ccq} initMeta f = runExceptT $ do sendInitialMetadata c ccq initMeta st <- liftIO $ f sc (streamRecvPrim c ccq) (streamSendPrim c ccq) @@ -448,24 +508,29 @@ serverRW' _ sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = -- values in the result tuple being the initial and trailing metadata -- respectively. We pass in the 'ServerCall' so that the server can call -- 'serverCallCancel' on it if needed. -type ServerHandlerLL - = ServerCall (MethodPayload 'Normal) - -> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) +type ServerHandlerLL = + ServerCall (MethodPayload 'Normal) -> + IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) -- | Wait for and then handle a normal (non-streaming) call. -serverHandleNormalCall :: Server - -> RegisteredMethod 'Normal - -> MetadataMap - -- ^ Initial server metadata - -> ServerHandlerLL - -> IO (Either GRPCIOError ()) +serverHandleNormalCall :: + Server -> + RegisteredMethod 'Normal -> + -- | Initial server metadata + MetadataMap -> + ServerHandlerLL -> + IO (Either GRPCIOError ()) serverHandleNormalCall s rm initMeta f = withServerCall s rm go where - go sc@ServerCall{ unsafeSC = c, callCQ = ccq } = do + go sc@ServerCall{unsafeSC = c, callCQ = ccq} = do (rsp, trailMeta, st, ds) <- f sc - void <$> runOps c ccq [ OpSendInitialMetadata initMeta - , OpRecvCloseOnServer - , OpSendMessage rsp - , OpSendStatusFromServer trailMeta st ds - ] + void + <$> runOps + c + ccq + [ OpSendInitialMetadata initMeta + , OpRecvCloseOnServer + , OpSendMessage rsp + , OpSendStatusFromServer trailMeta st ds + ] diff --git a/core/src/Network/GRPC/LowLevel/Server/Unregistered.hs b/core/src/Network/GRPC/LowLevel/Server/Unregistered.hs index ae06f1ec..37bcab53 100644 --- a/core/src/Network/GRPC/LowLevel/Server/Unregistered.hs +++ b/core/src/Network/GRPC/LowLevel/Server/Unregistered.hs @@ -1,34 +1,38 @@ -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE RecordWildCards #-} module Network.GRPC.LowLevel.Server.Unregistered where -import Control.Exception (bracket, finally, mask) -import Control.Monad -import Control.Monad.Trans.Except -import Data.ByteString (ByteString) -import Network.GRPC.LowLevel.Call.Unregistered -import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) -import Network.GRPC.LowLevel.GRPC -import Network.GRPC.LowLevel.Op -import Network.GRPC.LowLevel.Server (Server (..), - ServerRWHandlerLL, - ServerReaderHandlerLL, - ServerWriterHandlerLL, - forkServer, - serverReader', - serverWriter', - serverRW') -import qualified Network.GRPC.Unsafe.Op as C +import Control.Exception (bracket, finally, mask) +import Control.Monad +import Control.Monad.Trans.Except +import Data.ByteString (ByteString) +import Network.GRPC.LowLevel.Call.Unregistered +import Network.GRPC.LowLevel.CompletionQueue.Unregistered (serverRequestCall) +import Network.GRPC.LowLevel.GRPC +import Network.GRPC.LowLevel.Op +import Network.GRPC.LowLevel.Server ( + Server (..), + ServerRWHandlerLL, + ServerReaderHandlerLL, + ServerWriterHandlerLL, + forkServer, + serverRW', + serverReader', + serverWriter', + ) +import qualified Network.GRPC.Unsafe.Op as C -serverCreateCall :: Server - -> IO (Either GRPCIOError ServerCall) +serverCreateCall :: + Server -> + IO (Either GRPCIOError ServerCall) serverCreateCall Server{..} = serverRequestCall unsafeServer serverCQ serverCallCQ -withServerCall :: Server - -> (ServerCall -> IO (Either GRPCIOError a)) - -> IO (Either GRPCIOError a) +withServerCall :: + Server -> + (ServerCall -> IO (Either GRPCIOError a)) -> + IO (Either GRPCIOError a) withServerCall s f = bracket (serverCreateCall s) cleanup $ \case Left e -> return (Left e) @@ -44,100 +48,121 @@ withServerCall s f = -- Handles cleaning up the call safely. -- Because this function doesn't wait for the handler to return, it cannot -- return errors. -withServerCallAsync :: Server - -> (ServerCall -> IO ()) - -> IO () +withServerCallAsync :: + Server -> + (ServerCall -> IO ()) -> + IO () withServerCallAsync s f = mask $ \unmask -> unmask (serverCreateCall s) >>= \case - Left e -> do grpcDebug $ "withServerCallAsync: call error: " ++ show e - return () - Right c -> do wasForkSuccess <- forkServer s handler - unless wasForkSuccess destroy - where handler = unmask (f c) `finally` destroy - -- TODO: We sometimes never finish cleanup if the server - -- is shutting down and calls killThread. This causes gRPC - -- core to complain about leaks. I think the cause of - -- this is that killThread gets called after we are - -- already in destroyServerCall, and wrapping - -- uninterruptibleMask doesn't seem to help. Doesn't - -- crash, but does emit annoying log messages. - destroy = do - grpcDebug "withServerCallAsync: destroying." - destroyServerCall c - grpcDebug "withServerCallAsync: cleanup finished." + Left e -> do + grpcDebug $ "withServerCallAsync: call error: " ++ show e + return () + Right c -> do + wasForkSuccess <- forkServer s handler + unless wasForkSuccess destroy + where + handler = unmask (f c) `finally` destroy + -- TODO: We sometimes never finish cleanup if the server + -- is shutting down and calls killThread. This causes gRPC + -- core to complain about leaks. I think the cause of + -- this is that killThread gets called after we are + -- already in destroyServerCall, and wrapping + -- uninterruptibleMask doesn't seem to help. Doesn't + -- crash, but does emit annoying log messages. + destroy = do + grpcDebug "withServerCallAsync: destroying." + destroyServerCall c + grpcDebug "withServerCallAsync: cleanup finished." -- | A handler for an unregistered server call; bytestring arguments are the -- request body and response body respectively. -type ServerHandler - = ServerCall - -> ByteString - -> IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) +type ServerHandler = + ServerCall -> + ByteString -> + IO (ByteString, MetadataMap, C.StatusCode, StatusDetails) -- | Handle one unregistered call. -serverHandleNormalCall :: Server - -> MetadataMap -- ^ Initial server metadata. - -> ServerHandler - -> IO (Either GRPCIOError ()) +serverHandleNormalCall :: + Server -> + -- | Initial server metadata. + MetadataMap -> + ServerHandler -> + IO (Either GRPCIOError ()) serverHandleNormalCall s initMeta f = withServerCall s $ \c -> serverHandleNormalCall' s c initMeta f -serverHandleNormalCall' :: Server - -> ServerCall - -> MetadataMap -- ^ Initial server metadata. - -> ServerHandler - -> IO (Either GRPCIOError ()) +serverHandleNormalCall' :: + Server -> + ServerCall -> + -- | Initial server metadata. + MetadataMap -> + ServerHandler -> + IO (Either GRPCIOError ()) serverHandleNormalCall' - _ sc@ServerCall{ unsafeSC = c, callCQ = cq, .. } initMeta f = do - grpcDebug "serverHandleNormalCall(U): starting batch." - runOps c cq - [ OpSendInitialMetadata initMeta - , OpRecvMessage - ] - >>= \case - Left x -> do - grpcDebug "serverHandleNormalCall(U): ops failed; aborting" - return $ Left x - Right [OpRecvMessageResult (Just body)] -> do - grpcDebug $ "got client metadata: " ++ show metadata - grpcDebug $ "call_details host is: " ++ show callHost - (rsp, trailMeta, st, ds) <- f sc body - -- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, - -- or else the client times out. Given this, I have no idea how to - -- check for cancellation on the server. - runOps c cq - [ OpRecvCloseOnServer - , OpSendMessage rsp, - OpSendStatusFromServer trailMeta st ds - ] - >>= \case - Left x -> do - grpcDebug "serverHandleNormalCall(U): resp failed." - return $ Left x - Right _ -> do - grpcDebug "serverHandleNormalCall(U): ops done." - return $ Right () - x -> error $ "impossible pattern match: " ++ show x + _ + sc@ServerCall{unsafeSC = c, callCQ = cq, ..} + initMeta + f = do + grpcDebug "serverHandleNormalCall(U): starting batch." + runOps + c + cq + [ OpSendInitialMetadata initMeta + , OpRecvMessage + ] + >>= \case + Left x -> do + grpcDebug "serverHandleNormalCall(U): ops failed; aborting" + return $ Left x + Right [OpRecvMessageResult (Just body)] -> do + grpcDebug $ "got client metadata: " ++ show metadata + grpcDebug $ "call_details host is: " ++ show callHost + (rsp, trailMeta, st, ds) <- f sc body + -- TODO: We have to put 'OpRecvCloseOnServer' in the response ops, + -- or else the client times out. Given this, I have no idea how to + -- check for cancellation on the server. + runOps + c + cq + [ OpRecvCloseOnServer + , OpSendMessage rsp + , OpSendStatusFromServer trailMeta st ds + ] + >>= \case + Left x -> do + grpcDebug "serverHandleNormalCall(U): resp failed." + return $ Left x + Right _ -> do + grpcDebug "serverHandleNormalCall(U): ops done." + return $ Right () + x -> error $ "impossible pattern match: " ++ show x -serverReader :: Server - -> ServerCall - -> MetadataMap -- ^ Initial server metadata - -> ServerReaderHandlerLL - -> IO (Either GRPCIOError ()) +serverReader :: + Server -> + ServerCall -> + -- | Initial server metadata + MetadataMap -> + ServerReaderHandlerLL -> + IO (Either GRPCIOError ()) serverReader s = serverReader' s . convertCall -serverWriter :: Server - -> ServerCall - -> MetadataMap -- ^ Initial server metadata - -> ServerWriterHandlerLL - -> IO (Either GRPCIOError ()) -serverWriter s sc@ServerCall{ unsafeSC = c, callCQ = ccq } initMeta f = +serverWriter :: + Server -> + ServerCall -> + -- | Initial server metadata + MetadataMap -> + ServerWriterHandlerLL -> + IO (Either GRPCIOError ()) +serverWriter s sc@ServerCall{unsafeSC = c, callCQ = ccq} initMeta f = runExceptT $ do bs <- recvInitialMessage c ccq ExceptT (serverWriter' s (const bs <$> convertCall sc) initMeta f) -serverRW :: Server - -> ServerCall - -> MetadataMap -- ^ Initial server metadata - -> ServerRWHandlerLL - -> IO (Either GRPCIOError ()) +serverRW :: + Server -> + ServerCall -> + -- | Initial server metadata + MetadataMap -> + ServerRWHandlerLL -> + IO (Either GRPCIOError ()) serverRW s = serverRW' s . convertCall diff --git a/core/tests/LowLevelTests.hs b/core/tests/LowLevelTests.hs index 516f5aed..7e82e27b 100644 --- a/core/tests/LowLevelTests.hs +++ b/core/tests/LowLevelTests.hs @@ -1,70 +1,76 @@ -{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DataKinds #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE PatternSynonyms #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE TupleSections #-} -{-# OPTIONS_GHC -fno-warn-orphans #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE PatternSynonyms #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TupleSections #-} +{-# OPTIONS_GHC -fno-warn-orphans #-} module LowLevelTests where -import Control.Concurrent (threadDelay) -import Control.Concurrent.Async -import Control.Monad -import Control.Monad.Managed -import Data.ByteString (ByteString, - isPrefixOf, - isSuffixOf) -import Data.List (find) -import qualified Data.Map.Strict as M -import qualified Data.Set as S -import GHC.Exts (fromList, toList) -import Network.GRPC.LowLevel -import qualified Network.GRPC.LowLevel.Call.Unregistered as U +import Control.Concurrent (threadDelay) +import Control.Concurrent.Async +import Control.Monad +import Control.Monad.Managed +import Data.ByteString ( + ByteString, + isPrefixOf, + isSuffixOf, + ) +import Data.List (find) +import qualified Data.Map.Strict as M +import qualified Data.Set as S +import GHC.Exts (fromList, toList) +import Network.GRPC.LowLevel +import qualified Network.GRPC.LowLevel.Call.Unregistered as U import qualified Network.GRPC.LowLevel.Client.Unregistered as U -import Network.GRPC.LowLevel.GRPC (threadDelaySecs) +import Network.GRPC.LowLevel.GRPC (threadDelaySecs) import qualified Network.GRPC.LowLevel.Server.Unregistered as U -import qualified Pipes as P -import Test.Tasty -import Test.Tasty.HUnit as HU (Assertion, - assertBool, - assertEqual, - assertFailure, - testCase, - (@?=)) +import qualified Pipes as P +import Test.Tasty +import Test.Tasty.HUnit as HU ( + Assertion, + assertBool, + assertEqual, + assertFailure, + testCase, + (@?=), + ) lowLevelTests :: TestTree -lowLevelTests = testGroup "Unit tests of low-level Haskell library" - [ testGRPCBracket - , testCompletionQueueCreateDestroy - , testClientCreateDestroy - , testClientCall - , testClientTimeoutNoServer - , testServerCreateDestroy - , testMixRegisteredUnregistered - , testPayload - , testSSL - , testAuthMetadataTransfer - , testServerAuthProcessorCancel - , testPayloadUnregistered - , testServerCancel - , testGoaway - , testSlowServer - , testServerCallExpirationCheck - , testCustomUserAgent - , testClientCompression - , testClientServerCompression - , testClientMaxReceiveMessageLengthChannelArg - , testClientStreaming - , testClientStreamingUnregistered - , testServerStreaming - , testServerStreamingUnregistered - , testBiDiStreaming - , testBiDiStreamingUnregistered - ] +lowLevelTests = + testGroup + "Unit tests of low-level Haskell library" + [ testGRPCBracket + , testCompletionQueueCreateDestroy + , testClientCreateDestroy + , testClientCall + , testClientTimeoutNoServer + , testServerCreateDestroy + , testMixRegisteredUnregistered + , testPayload + , testSSL + , testAuthMetadataTransfer + , testServerAuthProcessorCancel + , testPayloadUnregistered + , testServerCancel + , testGoaway + , testSlowServer + , testServerCallExpirationCheck + , testCustomUserAgent + , testClientCompression + , testClientServerCompression + , testClientMaxReceiveMessageLengthChannelArg + , testClientStreaming + , testClientStreamingUnregistered + , testServerStreaming + , testServerStreamingUnregistered + , testBiDiStreaming + , testBiDiStreamingUnregistered + ] testGRPCBracket :: TestTree testGRPCBracket = @@ -84,19 +90,20 @@ testClientTimeoutNoServer :: TestTree testClientTimeoutNoServer = clientOnlyTest "request timeout when server DNE" $ \c -> do rm <- clientRegisterMethodNormal c "/foo" - r <- clientRequest c rm 1 "Hello" mempty + r <- clientRequest c rm 1 "Hello" mempty r @?= Left GRPCIOTimeout testServerCreateDestroy :: TestTree testServerCreateDestroy = - serverOnlyTest "start/stop" (["/foo"],[],[],[]) nop + serverOnlyTest "start/stop" (["/foo"], [], [], []) nop testMixRegisteredUnregistered :: TestTree testMixRegisteredUnregistered = - csTest "server uses unregistered calls to handle unknown endpoints" - client - server - (["/foo"],[],[],[]) + csTest + "server uses unregistered calls to handle unknown endpoints" + client + server + (["/foo"], [], [], []) where client c = do rm1 <- clientRegisterMethodNormal c "/foo" @@ -111,20 +118,25 @@ testMixRegisteredUnregistered = rspBody @?= "" return () server s = do - concurrently regThread unregThread - return () - where regThread = do - let rm = head (normalMethods s) - _r <- serverHandleNormalCall s rm dummyMeta $ \c -> do - payload c @?= "Hello" - return ("reply test", dummyMeta, StatusOk, "") - return () - unregThread = do - U.serverHandleNormalCall s mempty $ \call _ -> do - U.callMethod call @?= "/bar" - return ("", mempty, StatusOk, - StatusDetails "Wrong endpoint") - return () + concurrently regThread unregThread + return () + where + regThread = do + let rm = head (normalMethods s) + _r <- serverHandleNormalCall s rm dummyMeta $ \c -> do + payload c @?= "Hello" + return ("reply test", dummyMeta, StatusOk, "") + return () + unregThread = do + U.serverHandleNormalCall s mempty $ \call _ -> do + U.callMethod call @?= "/bar" + return + ( "" + , mempty + , StatusOk + , StatusDetails "Wrong endpoint" + ) + return () -- TODO: There seems to be a race here (and in other client/server pairs, of -- course) about what gets reported when there is a failure. E.g., if one of the @@ -134,11 +146,13 @@ testMixRegisteredUnregistered = -- tweak EH behavior / async use. testPayload :: TestTree testPayload = - csTest "registered normal request/response" client server (["/foo"],[],[],[]) + csTest "registered normal request/response" client server (["/foo"], [], [], []) where - clientMD = [ ("foo_key", "foo_val") - , ("bar_key", "bar_val") - , ("bar_key", "bar_repeated_val")] + clientMD = + [ ("foo_key", "foo_val") + , ("bar_key", "bar_val") + , ("bar_key", "bar_repeated_val") + ] client c = do rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 10 "Hello!" clientMD >>= do @@ -146,7 +160,7 @@ testPayload = rspCode @?= StatusOk rspBody @?= "reply test" details @?= "details string" - initMD @?= dummyMeta + initMD @?= dummyMeta trailMD @?= dummyMeta server s = do let rm = head (normalMethods s) @@ -160,12 +174,16 @@ testSSL :: TestTree testSSL = csTest' "request/response using SSL" client server where - clientConf = stdClientConf - {clientSSLConfig = Just (ClientSSLConfig - (Just "tests/ssl/localhost.crt") - Nothing - Nothing) - } + clientConf = + stdClientConf + { clientSSLConfig = + Just + ( ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + Nothing + ) + } client = TestClient clientConf $ \c -> do rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 10 "hi" mempty >>= do @@ -173,14 +191,18 @@ testSSL = rspCode @?= StatusOk rspBody @?= "reply test" - serverConf' = defServerConf - { sslConfig = Just (ServerSSLConfig - Nothing - "tests/ssl/localhost.key" - "tests/ssl/localhost.crt" - SslDontRequestClientCertificate - Nothing) - } + serverConf' = + defServerConf + { sslConfig = + Just + ( ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + Nothing + ) + } server = TestServer serverConf' $ \s -> do r <- U.serverHandleNormalCall s mempty $ \U.ServerCall{} body -> do body @?= "hi" @@ -197,41 +219,51 @@ testServerAuthProcessorCancel :: TestTree testServerAuthProcessorCancel = csTest' "request rejection by auth processor" client server where - clientConf = stdClientConf - {clientSSLConfig = Just (ClientSSLConfig - (Just "tests/ssl/localhost.crt") - Nothing - Nothing) - } + clientConf = + stdClientConf + { clientSSLConfig = + Just + ( ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + Nothing + ) + } client = TestClient clientConf $ \c -> do rm <- clientRegisterMethodNormal c "/foo" r <- clientRequest c rm 10 "hi" mempty -- TODO: using checkReqRslt on this first result causes the test to hang! r @?= Left (GRPCIOBadStatusCode StatusUnauthenticated "denied!") - clientRequest c rm 10 "hi" [("foo","bar")] >>= do + clientRequest c rm 10 "hi" [("foo", "bar")] >>= do checkReqRslt $ \NormalRequestResult{..} -> do rspCode @?= StatusOk rspBody @?= "reply test" serverProcessor = Just $ \_ m -> do - let (status, details) = if M.member "foo" (unMap m) - then (StatusOk, "") - else (StatusUnauthenticated, "denied!") + let (status, details) = + if M.member "foo" (unMap m) + then (StatusOk, "") + else (StatusUnauthenticated, "denied!") return $ AuthProcessorResult mempty mempty status details - serverConf' = defServerConf - { sslConfig = Just (ServerSSLConfig - Nothing - "tests/ssl/localhost.key" - "tests/ssl/localhost.crt" - SslDontRequestClientCertificate - serverProcessor) - } + serverConf' = + defServerConf + { sslConfig = + Just + ( ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + serverProcessor + ) + } server = TestServer serverConf' $ \s -> do r <- U.serverHandleNormalCall s mempty $ \U.ServerCall{..} _body -> do - checkMD "Handler only sees requests with good metadata" - [("foo","bar")] - metadata + checkMD + "Handler only sees requests with good metadata" + [("foo", "bar")] + metadata return ("reply test", mempty, StatusOk, "") r @?= Right () @@ -247,13 +279,17 @@ testAuthMetadataTransfer = newProps <- getAuthProperties authCtx let addedProp = find ((== "foo1") . authPropName) newProps addedProp @?= Just (AuthProperty "foo1" "bar1") - return $ ClientMetadataCreateResult [("foo","bar")] StatusOk "" - clientConf = stdClientConf - {clientSSLConfig = Just (ClientSSLConfig - (Just "tests/ssl/localhost.crt") - Nothing - (Just plugin)) - } + return $ ClientMetadataCreateResult [("foo", "bar")] StatusOk "" + clientConf = + stdClientConf + { clientSSLConfig = + Just + ( ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + (Just plugin) + ) + } client = TestClient clientConf $ \c -> do rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 10 "hi" mempty >>= do @@ -263,23 +299,28 @@ testAuthMetadataTransfer = serverProcessor :: Maybe ProcessMeta serverProcessor = Just $ \authCtx m -> do - let expected = fromList [("foo","bar")] + let expected = fromList [("foo", "bar")] props <- getAuthProperties authCtx let clientProp = find ((== "foo1") . authPropName) props - assertBool "server plugin doesn't see auth properties set by client" - (clientProp == Nothing) + assertBool + "server plugin doesn't see auth properties set by client" + (clientProp == Nothing) checkMD "server plugin sees metadata added by client plugin" expected m return $ AuthProcessorResult mempty mempty StatusOk "" - serverConf' = defServerConf - { sslConfig = Just (ServerSSLConfig - Nothing - "tests/ssl/localhost.key" - "tests/ssl/localhost.crt" - SslDontRequestClientCertificate - serverProcessor) - } + serverConf' = + defServerConf + { sslConfig = + Just + ( ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + serverProcessor + ) + } server = TestServer serverConf' $ \s -> do r <- U.serverHandleNormalCall s mempty $ \U.ServerCall{} body -> do body @?= "hi" @@ -300,13 +341,17 @@ testAuthMetadataPropagate = testCase "auth metadata inherited by children" $ do return () where clientPlugin _ = - return $ ClientMetadataCreateResult [("foo","bar")] StatusOk "" - clientConf = stdClientConf - {clientSSLConfig = Just (ClientSSLConfig - (Just "tests/ssl/localhost.crt") - Nothing - (Just clientPlugin)) - } + return $ ClientMetadataCreateResult [("foo", "bar")] StatusOk "" + clientConf = + stdClientConf + { clientSSLConfig = + Just + ( ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + (Just clientPlugin) + ) + } client = do threadDelaySecs 3 withGRPC $ \g -> withClient g clientConf $ \c -> do @@ -317,30 +362,38 @@ testAuthMetadataPropagate = testCase "auth metadata inherited by children" $ do rspBody @?= "reply test" server1ServerPlugin _ctx md = do - checkMD "server1 sees client's auth metadata." [("foo","bar")] md + checkMD "server1 sees client's auth metadata." [("foo", "bar")] md -- TODO: add response meta to check, and consume meta to see what happens. return $ AuthProcessorResult mempty mempty StatusOk "" - server1ServerConf = defServerConf - {sslConfig = Just (ServerSSLConfig - Nothing - "tests/ssl/localhost.key" - "tests/ssl/localhost.crt" - SslDontRequestClientCertificate - (Just server1ServerPlugin)), - methodsToRegisterNormal = ["/foo"] - } + server1ServerConf = + defServerConf + { sslConfig = + Just + ( ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + (Just server1ServerPlugin) + ) + , methodsToRegisterNormal = ["/foo"] + } server1ClientPlugin _ = - return $ ClientMetadataCreateResult [("foo1","bar1")] StatusOk "" - - server1ClientConf = stdClientConf - {clientSSLConfig = Just (ClientSSLConfig - (Just "tests/ssl/localhost.crt") - Nothing - (Just server1ClientPlugin)), - clientServerEndpoint = "localhost:50052" - } + return $ ClientMetadataCreateResult [("foo1", "bar1")] StatusOk "" + + server1ClientConf = + stdClientConf + { clientSSLConfig = + Just + ( ClientSSLConfig + (Just "tests/ssl/localhost.crt") + Nothing + (Just server1ClientPlugin) + ) + , clientServerEndpoint = "localhost:50052" + } server = do threadDelaySecs 2 @@ -358,30 +411,34 @@ testAuthMetadataPropagate = testCase "auth metadata inherited by children" $ do server2ServerPlugin _ctx md = do print md - checkMD "server2 sees server1's auth metadata." [("foo1","bar1")] md - --TODO: this assert fails - checkMD "server2 sees client's auth metadata." [("foo","bar")] md + checkMD "server2 sees server1's auth metadata." [("foo1", "bar1")] md + -- TODO: this assert fails + checkMD "server2 sees client's auth metadata." [("foo", "bar")] md return $ AuthProcessorResult mempty mempty StatusOk "" - server2ServerConf = defServerConf - {sslConfig = Just (ServerSSLConfig - Nothing - "tests/ssl/localhost.key" - "tests/ssl/localhost.crt" - SslDontRequestClientCertificate - (Just server2ServerPlugin)), - methodsToRegisterNormal = ["/foo"], - port = 50052 - } + server2ServerConf = + defServerConf + { sslConfig = + Just + ( ServerSSLConfig + Nothing + "tests/ssl/localhost.key" + "tests/ssl/localhost.crt" + SslDontRequestClientCertificate + (Just server2ServerPlugin) + ) + , methodsToRegisterNormal = ["/foo"] + , port = 50052 + } server2 = withGRPC $ \g -> withServer g server2ServerConf $ \s -> do - let rm = head (normalMethods s) - serverHandleNormalCall s rm mempty $ \_call -> do - return ("server2 reply", mempty, StatusOk, "") + let rm = head (normalMethods s) + serverHandleNormalCall s rm mempty $ \_call -> do + return ("server2 reply", mempty, StatusOk, "") testServerCancel :: TestTree testServerCancel = - csTest "server cancel call" client server (["/foo"],[],[],[]) + csTest "server cancel call" client server (["/foo"], [], [], []) where client c = do rm <- clientRegisterMethodNormal c "/foo" @@ -396,12 +453,12 @@ testServerCancel = testServerStreaming :: TestTree testServerStreaming = - csTest "server streaming" client server ([],[],["/feed"],[]) + csTest "server streaming" client server ([], [], ["/feed"], []) where - clientInitMD = [("client","initmd")] - serverInitMD = [("server","initmd")] - clientPay = "FEED ME!" - pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString] + clientInitMD = [("client", "initmd")] + serverInitMD = [("server", "initmd")] + clientPay = "FEED ME!" + pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString] client c = do rm <- clientRegisterMethodServerStreaming c "/feed" @@ -426,12 +483,12 @@ testServerStreaming = -- to using them in these tests. testServerStreamingUnregistered :: TestTree testServerStreamingUnregistered = - csTest "unregistered server streaming" client server ([],[],[],[]) + csTest "unregistered server streaming" client server ([], [], [], []) where - clientInitMD = [("client","initmd")] - serverInitMD = [("server","initmd")] - clientPay = "FEED ME!" - pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString] + clientInitMD = [("client", "initmd")] + serverInitMD = [("server", "initmd")] + clientPay = "FEED ME!" + pays = ["ONE", "TWO", "THREE", "FOUR"] :: [ByteString] client c = do rm <- clientRegisterMethodServerStreaming c "/feed" @@ -451,18 +508,18 @@ testServerStreamingUnregistered = testClientStreaming :: TestTree testClientStreaming = - csTest "client streaming" client server ([],["/slurp"],[],[]) + csTest "client streaming" client server ([], ["/slurp"], [], []) where - clientInitMD = [("a","b")] - serverInitMD = [("x","y")] - trailMD = dummyMeta - serverRsp = "serverReader reply" - serverDtls = "deets" + clientInitMD = [("a", "b")] + serverInitMD = [("x", "y")] + trailMD = dummyMeta + serverRsp = "serverReader reply" + serverDtls = "deets" serverStatus = StatusOk - pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString] + pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString] client c = do - rm <- clientRegisterMethodClientStreaming c "/slurp" + rm <- clientRegisterMethodClientStreaming c "/slurp" eea <- clientWriter c rm 10 clientInitMD $ \send -> do -- liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD forM_ pays $ \p -> send p `is` Right () @@ -479,18 +536,18 @@ testClientStreaming = testClientStreamingUnregistered :: TestTree testClientStreamingUnregistered = - csTest "unregistered client streaming" client server ([],[],[],[]) + csTest "unregistered client streaming" client server ([], [], [], []) where - clientInitMD = [("a","b")] - serverInitMD = [("x","y")] - trailMD = dummyMeta - serverRsp = "serverReader reply" - serverDtls = "deets" + clientInitMD = [("a", "b")] + serverInitMD = [("x", "y")] + trailMD = dummyMeta + serverRsp = "serverReader reply" + serverDtls = "deets" serverStatus = StatusOk - pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString] + pays = ["P_ONE", "P_TWO", "P_THREE"] :: [ByteString] client c = do - rm <- clientRegisterMethodClientStreaming c "/slurp" + rm <- clientRegisterMethodClientStreaming c "/slurp" eea <- clientWriter c rm 10 clientInitMD $ \send -> do -- liftIO $ checkMD "Server initial metadata mismatch" serverInitMD initMD forM_ pays $ \p -> send p `is` Right () @@ -506,72 +563,72 @@ testClientStreamingUnregistered = testBiDiStreaming :: TestTree testBiDiStreaming = - csTest "bidirectional streaming" client server ([],[],[],["/bidi"]) + csTest "bidirectional streaming" client server ([], [], [], ["/bidi"]) where - clientInitMD = [("bidi-streaming","client")] - serverInitMD = [("bidi-streaming","server")] - trailMD = dummyMeta + clientInitMD = [("bidi-streaming", "client")] + serverInitMD = [("bidi-streaming", "server")] + trailMD = dummyMeta serverStatus = StatusOk - serverDtls = "deets" + serverDtls = "deets" client c = do - rm <- clientRegisterMethodBiDiStreaming c "/bidi" + rm <- clientRegisterMethodBiDiStreaming c "/bidi" eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD send "cw0" `is` Right () - recv `is` Right (Just "sw0") + recv `is` Right (Just "sw0") send "cw1" `is` Right () - recv `is` Right (Just "sw1") - recv `is` Right (Just "sw2") + recv `is` Right (Just "sw1") + recv `is` Right (Just "sw2") writesDone `is` Right () - recv `is` Right Nothing + recv `is` Right Nothing eea @?= Right (trailMD, serverStatus, serverDtls) server s = do let rm = head (bidiStreamingMethods s) eea <- serverRW s rm serverInitMD $ \sc recv send -> do checkMD "Client request metadata mismatch" clientInitMD (metadata sc) - recv `is` Right (Just "cw0") + recv `is` Right (Just "cw0") send "sw0" `is` Right () - recv `is` Right (Just "cw1") + recv `is` Right (Just "cw1") send "sw1" `is` Right () send "sw2" `is` Right () - recv `is` Right Nothing + recv `is` Right Nothing return (trailMD, serverStatus, serverDtls) eea @?= Right () testBiDiStreamingUnregistered :: TestTree testBiDiStreamingUnregistered = - csTest "unregistered bidirectional streaming" client server ([],[],[],[]) + csTest "unregistered bidirectional streaming" client server ([], [], [], []) where - clientInitMD = [("bidi-streaming","client")] - serverInitMD = [("bidi-streaming","server")] - trailMD = dummyMeta + clientInitMD = [("bidi-streaming", "client")] + serverInitMD = [("bidi-streaming", "server")] + trailMD = dummyMeta serverStatus = StatusOk - serverDtls = "deets" + serverDtls = "deets" client c = do - rm <- clientRegisterMethodBiDiStreaming c "/bidi" + rm <- clientRegisterMethodBiDiStreaming c "/bidi" eea <- clientRW c rm 10 clientInitMD $ \_cc getMD recv send writesDone -> do either clientFail (checkMD "Server rsp metadata mismatch" serverInitMD) =<< getMD send "cw0" `is` Right () - recv `is` Right (Just "sw0") + recv `is` Right (Just "sw0") send "cw1" `is` Right () - recv `is` Right (Just "sw1") - recv `is` Right (Just "sw2") + recv `is` Right (Just "sw1") + recv `is` Right (Just "sw2") writesDone `is` Right () - recv `is` Right Nothing + recv `is` Right Nothing eea @?= Right (trailMD, serverStatus, serverDtls) server s = U.withServerCallAsync s $ \call -> do eea <- U.serverRW s call serverInitMD $ \sc recv send -> do checkMD "Client request metadata mismatch" clientInitMD (metadata sc) - recv `is` Right (Just "cw0") + recv `is` Right (Just "cw0") send "sw0" `is` Right () - recv `is` Right (Just "cw1") + recv `is` Right (Just "cw1") send "sw1" `is` Right () send "sw2" `is` Right () - recv `is` Right Nothing + recv `is` Right Nothing return (trailMD, serverStatus, serverDtls) eea @?= Right () @@ -586,13 +643,13 @@ testClientCall = testServerCall :: TestTree testServerCall = - serverOnlyTest "create/destroy call" ([],[],[],[]) $ \s -> do + serverOnlyTest "create/destroy call" ([], [], [], []) $ \s -> do r <- U.withServerCall s $ const $ return $ Right () r @?= Left GRPCIOTimeout testPayloadUnregistered :: TestTree testPayloadUnregistered = - csTest "unregistered normal request/response" client server ([],[],[],[]) + csTest "unregistered normal request/response" client server ([], [], [], []) where client c = U.clientRequest c "/foo" 10 "Hello!" mempty >>= do @@ -602,17 +659,18 @@ testPayloadUnregistered = details @?= "details string" server s = do r <- U.serverHandleNormalCall s mempty $ \U.ServerCall{..} body -> do - body @?= "Hello!" - callMethod @?= "/foo" - return ("reply test", mempty, StatusOk, "details string") + body @?= "Hello!" + callMethod @?= "/foo" + return ("reply test", mempty, StatusOk, "details string") r @?= Right () testGoaway :: TestTree testGoaway = - csTest "Client handles server shutdown gracefully" - client - server - (["/foo"],[],[],[]) + csTest + "Client handles server shutdown gracefully" + client + server + (["/foo"], [], [], []) where client c = do rm <- clientRegisterMethodNormal c "/foo" @@ -620,10 +678,10 @@ testGoaway = clientRequest c rm 10 "" mempty eer <- clientRequest c rm 1 "" mempty assertBool "Client handles server shutdown gracefully" $ case eer of - Left (GRPCIOBadStatusCode StatusUnavailable _) -> True + Left (GRPCIOBadStatusCode StatusUnavailable _) -> True Left (GRPCIOBadStatusCode StatusDeadlineExceeded "Deadline Exceeded") -> True - Left GRPCIOTimeout -> True - _ -> False + Left GRPCIOTimeout -> True + _ -> False server s = do let rm = head (normalMethods s) @@ -633,7 +691,7 @@ testGoaway = testSlowServer :: TestTree testSlowServer = - csTest "Client handles slow server response" client server (["/foo"],[],[],[]) + csTest "Client handles slow server response" client server (["/foo"], [], [], []) where client c = do rm <- clientRegisterMethodNormal c "/foo" @@ -642,13 +700,13 @@ testSlowServer = server s = do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \_ -> do - threadDelay (2*10^(6 :: Int)) + threadDelay (2 * 10 ^ (6 :: Int)) return dummyResp return () testServerCallExpirationCheck :: TestTree testServerCallExpirationCheck = - csTest "Check for call expiration" client server (["/foo"],[],[],[]) + csTest "Check for call expiration" client server (["/foo"], [], [], []) where client c = do rm <- clientRegisterMethodNormal c "/foo" @@ -674,15 +732,16 @@ testCustomUserAgent = clientArgs = [UserAgentPrefix "prefix!", UserAgentSuffix "suffix!"] client = TestClient (ClientConfig "localhost:50051" clientArgs Nothing Nothing) $ - \c -> do rm <- clientRegisterMethodNormal c "/foo" - void $ clientRequest c rm 4 "" mempty - server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do + \c -> do + rm <- clientRegisterMethodNormal c "/foo" + void $ clientRequest c rm 4 "" mempty + server = TestServer (serverConf (["/foo"], [], [], [])) $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \c -> do ua <- case toList $ (unMap $ metadata c) M.! "user-agent" of - [] -> fail "user-agent missing from metadata." - [ua] -> return ua - _ -> fail "multiple user-agent keys." + [] -> fail "user-agent missing from metadata." + [ua] -> return ua + _ -> fail "multiple user-agent keys." assertBool "User agent prefix is present" $ isPrefixOf "prefix!" ua assertBool "User agent suffix is present" $ isSuffixOf "suffix!" ua return dummyResp @@ -693,14 +752,17 @@ testClientCompression = csTest' "client-only compression: no errors" client server where client = - TestClient (ClientConfig - "localhost:50051" - [CompressionAlgArg GrpcCompressDeflate] - Nothing - Nothing) $ \c -> do - rm <- clientRegisterMethodNormal c "/foo" - void $ clientRequest c rm 1 "hello" mempty - server = TestServer (serverConf (["/foo"],[],[],[])) $ \s -> do + TestClient + ( ClientConfig + "localhost:50051" + [CompressionAlgArg GrpcCompressDeflate] + Nothing + Nothing + ) + $ \c -> do + rm <- clientRegisterMethodNormal c "/foo" + void $ clientRequest c rm 1 "hello" mempty + server = TestServer (serverConf (["/foo"], [], [], [])) $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm mempty $ \c -> do payload c @?= "hello" @@ -711,10 +773,12 @@ testClientServerCompression :: TestTree testClientServerCompression = csTest' "client/server compression: no errors" client server where - cconf = ClientConfig "localhost:50051" - [CompressionAlgArg GrpcCompressDeflate] - Nothing - Nothing + cconf = + ClientConfig + "localhost:50051" + [CompressionAlgArg GrpcCompressDeflate] + Nothing + Nothing client = TestClient cconf $ \c -> do rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 1 "hello" mempty >>= do @@ -722,14 +786,19 @@ testClientServerCompression = rspCode @?= StatusOk rspBody @?= "hello" details @?= "" - initMD @?= dummyMeta + initMD @?= dummyMeta trailMD @?= dummyMeta return () - sconf = ServerConfig "localhost" - 50051 - ["/foo"] [] [] [] - [CompressionAlgArg GrpcCompressDeflate] - Nothing + sconf = + ServerConfig + "localhost" + 50051 + ["/foo"] + [] + [] + [] + [CompressionAlgArg GrpcCompressDeflate] + Nothing server = TestServer sconf $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm dummyMeta $ \sc -> do @@ -741,10 +810,12 @@ testClientServerCompressionLvl :: TestTree testClientServerCompressionLvl = csTest' "client/server compression: no errors" client server where - cconf = ClientConfig "localhost:50051" - [CompressionLevelArg GrpcCompressLevelHigh] - Nothing - Nothing + cconf = + ClientConfig + "localhost:50051" + [CompressionLevelArg GrpcCompressLevelHigh] + Nothing + Nothing client = TestClient cconf $ \c -> do rm <- clientRegisterMethodNormal c "/foo" clientRequest c rm 1 "hello" mempty >>= do @@ -752,14 +823,19 @@ testClientServerCompressionLvl = rspCode @?= StatusOk rspBody @?= "hello" details @?= "" - initMD @?= dummyMeta + initMD @?= dummyMeta trailMD @?= dummyMeta return () - sconf = ServerConfig "localhost" - 50051 - ["/foo"] [] [] [] - [CompressionLevelArg GrpcCompressLevelLow] - Nothing + sconf = + ServerConfig + "localhost" + 50051 + ["/foo"] + [] + [] + [] + [CompressionLevelArg GrpcCompressLevelLow] + Nothing server = TestServer sconf $ \s -> do let rm = head (normalMethods s) serverHandleNormalCall s rm dummyMeta $ \sc -> do @@ -769,13 +845,14 @@ testClientServerCompressionLvl = testClientMaxReceiveMessageLengthChannelArg :: TestTree testClientMaxReceiveMessageLengthChannelArg = do - testGroup "max receive message length channel arg (client channel)" + testGroup + "max receive message length channel arg (client channel)" [ csTest' "payload size < small bound succeeds" shouldSucceed server - , csTest' "payload size > small bound fails" shouldFail server + , csTest' "payload size > small bound fails" shouldFail server ] where -- The server always sends a 4-byte payload - pay = "four" + pay = "four" server = TestServer (ServerConfig "localhost" 50051 ["/foo"] [] [] [] [] Nothing) $ \s -> do let rm = head (normalMethods s) void $ serverHandleNormalCall s rm mempty $ \sc -> do @@ -798,10 +875,10 @@ testClientMaxReceiveMessageLengthChannelArg = do -- Expect failure when the max recv payload size is set to 3 bytes, and we -- are sent 4. shouldFail = clientMax 3 $ \case - Left (GRPCIOBadStatusCode StatusResourceExhausted _) - -> pure () - rsp - -> clientFail ("Expected failure response, but got: " ++ show rsp) + Left (GRPCIOBadStatusCode StatusResourceExhausted _) -> + pure () + rsp -> + clientFail ("Expected failure response, but got: " ++ show rsp) -------------------------------------------------------------------------------- -- Utilities and helpers @@ -810,26 +887,29 @@ is :: (Eq a, Show a, MonadIO m) => m a -> a -> m () is act x = act >>= liftIO . (@?= x) dummyMeta :: MetadataMap -dummyMeta = [("foo","bar")] +dummyMeta = [("foo", "bar")] dummyResp :: (ByteString, MetadataMap, StatusCode, StatusDetails) dummyResp = ("", mempty, StatusOk, StatusDetails "") -dummyHandler :: ServerCall a - -> IO (ByteString, MetadataMap, StatusCode, StatusDetails) +dummyHandler :: + ServerCall a -> + IO (ByteString, MetadataMap, StatusCode, StatusDetails) dummyHandler _ = return dummyResp -dummyResult' :: StatusDetails - -> IO (ByteString, MetadataMap, StatusCode, StatusDetails) -dummyResult' = return . (mempty, mempty, StatusOk, ) +dummyResult' :: + StatusDetails -> + IO (ByteString, MetadataMap, StatusCode, StatusDetails) +dummyResult' = return . (mempty,mempty,StatusOk,) -nop :: Monad m => a -> m () +nop :: (Monad m) => a -> m () nop = const (return ()) -serverOnlyTest :: TestName - -> ([MethodName],[MethodName],[MethodName],[MethodName]) - -> (Server -> IO ()) - -> TestTree +serverOnlyTest :: + TestName -> + ([MethodName], [MethodName], [MethodName], [MethodName]) -> + (Server -> IO ()) -> + TestTree serverOnlyTest nm ms = testCase ("Server - " ++ nm) . runTestServer . TestServer (serverConf ms) @@ -837,18 +917,19 @@ clientOnlyTest :: TestName -> (Client -> IO ()) -> TestTree clientOnlyTest nm = testCase ("Client - " ++ nm) . runTestClient . stdTestClient -csTest :: TestName - -> (Client -> IO ()) - -> (Server -> IO ()) - -> ([MethodName],[MethodName],[MethodName],[MethodName]) - -> TestTree +csTest :: + TestName -> + (Client -> IO ()) -> + (Server -> IO ()) -> + ([MethodName], [MethodName], [MethodName], [MethodName]) -> + TestTree csTest nm c s ms = csTest' nm (stdTestClient c) (TestServer (serverConf ms) s) csTest' :: TestName -> TestClient -> TestServer -> TestTree csTest' nm tc ts = - testCase ("Client/Server - " ++ nm) - $ void (s `concurrently` c) + testCase ("Client/Server - " ++ nm) $ + void (s `concurrently` c) where -- We use a small delay to give the server a head start c = threadDelay 100000 >> runTestClient tc @@ -858,12 +939,12 @@ csTest' nm tc ts = -- @actual@, or when values differ for matching keys. checkMD :: String -> MetadataMap -> MetadataMap -> Assertion checkMD desc expected actual = - assertEqual desc expected' (actual' `S.intersection` expected') + assertEqual desc expected' (actual' `S.intersection` expected') where expected' = fromList . toList $ expected actual' = fromList . toList $ actual -checkReqRslt :: Show a => (b -> Assertion) -> Either a b -> Assertion +checkReqRslt :: (Show a) => (b -> Assertion) -> Either a b -> Assertion checkReqRslt = either clientFail -- | The consumer which asserts that the next value it consumes is equal to the @@ -871,8 +952,8 @@ checkReqRslt = either clientFail assertConsumeEq :: (Eq a, Show a) => String -> a -> P.Consumer a IO () assertConsumeEq s v = P.lift . assertEqual s v =<< P.await -clientFail :: Show a => a -> Assertion -clientFail = assertFailure . ("Client error: " ++). show +clientFail :: (Show a) => a -> Assertion +clientFail = assertFailure . ("Client error: " ++) . show data TestClient = TestClient ClientConfig (Client -> IO ()) @@ -895,13 +976,16 @@ runTestServer (TestServer conf f) = defServerConf :: ServerConfig defServerConf = ServerConfig "localhost" 50051 [] [] [] [] [] Nothing -serverConf :: ([MethodName],[MethodName],[MethodName],[MethodName]) - -> ServerConfig +serverConf :: + ([MethodName], [MethodName], [MethodName], [MethodName]) -> + ServerConfig serverConf (ns, cs, ss, bs) = - defServerConf {methodsToRegisterNormal = ns, - methodsToRegisterClientStreaming = cs, - methodsToRegisterServerStreaming = ss, - methodsToRegisterBiDiStreaming = bs} + defServerConf + { methodsToRegisterNormal = ns + , methodsToRegisterClientStreaming = cs + , methodsToRegisterServerStreaming = ss + , methodsToRegisterBiDiStreaming = bs + } mgdGRPC :: Managed GRPC mgdGRPC = managed withGRPC diff --git a/core/tests/LowLevelTests/Op.hs b/core/tests/LowLevelTests/Op.hs index 09b42f4d..01f5b8e6 100644 --- a/core/tests/LowLevelTests/Op.hs +++ b/core/tests/LowLevelTests/Op.hs @@ -27,7 +27,7 @@ testCancelFromServer = testCase "Client/Server - client receives server cancellation" $ runSerialTest $ \grpc -> withClientServerUnaryCall grpc $ - \(Client {..}, Server {}, ClientCall {..}, sc@ServerCall {}) -> do + \(Client{..}, Server{}, ClientCall{..}, sc@ServerCall{}) -> do serverCallCancel sc StatusPermissionDenied "TestStatus" clientRes <- runOps unsafeCC clientCQ clientRecvOps case clientRes of @@ -45,10 +45,10 @@ runSerialTest f = withClientServerUnaryCall :: GRPC -> - ( ( Client, - Server, - ClientCall, - ServerCall ByteString + ( ( Client + , Server + , ClientCall + , ServerCall ByteString ) -> IO (Either GRPCIOError a) ) -> @@ -83,21 +83,21 @@ clientConf = ClientConfig "localhost:50051" [] Nothing Nothing clientEmptySendOps :: [Op] clientEmptySendOps = - [ OpSendInitialMetadata mempty, - OpSendMessage "", - OpSendCloseFromClient + [ OpSendInitialMetadata mempty + , OpSendMessage "" + , OpSendCloseFromClient ] clientRecvOps :: [Op] clientRecvOps = - [ OpRecvInitialMetadata, - OpRecvMessage, - OpRecvStatusOnClient + [ OpRecvInitialMetadata + , OpRecvMessage + , OpRecvStatusOnClient ] serverEmptyRecvOps :: [Op] serverEmptyRecvOps = - [ OpSendInitialMetadata mempty, - OpRecvMessage, - OpRecvCloseOnServer + [ OpSendInitialMetadata mempty + , OpRecvMessage + , OpRecvCloseOnServer ] diff --git a/core/tests/Properties.hs b/core/tests/Properties.hs index d184a296..e91b9919 100644 --- a/core/tests/Properties.hs +++ b/core/tests/Properties.hs @@ -1,12 +1,15 @@ -import LowLevelTests -import LowLevelTests.Op -import Test.Tasty -import UnsafeTests +import LowLevelTests +import LowLevelTests.Op +import Test.Tasty +import UnsafeTests main :: IO () -main = defaultMain $ testGroup "GRPC Unit Tests" - [ unsafeTests - , unsafeProperties - , lowLevelOpTests - , lowLevelTests - ] +main = + defaultMain $ + testGroup + "GRPC Unit Tests" + [ unsafeTests + , unsafeProperties + , lowLevelOpTests + , lowLevelTests + ] diff --git a/core/tests/UnsafeTests.hs b/core/tests/UnsafeTests.hs index 5a24cb8d..ddac2703 100644 --- a/core/tests/UnsafeTests.hs +++ b/core/tests/UnsafeTests.hs @@ -1,6 +1,6 @@ {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE TypeApplications #-} {-# LANGUAGE ScopedTypeVariables #-} +{-# LANGUAGE TypeApplications #-} {-# OPTIONS_GHC -fno-warn-orphans #-} module UnsafeTests (unsafeTests, unsafeProperties) where @@ -8,8 +8,8 @@ module UnsafeTests (unsafeTests, unsafeProperties) where import Control.Exception (bracket_) import Control.Monad import qualified Data.ByteString as B +import Data.List.NonEmpty (NonEmpty ((:|))) import qualified Data.Map as M -import Data.List.NonEmpty (NonEmpty((:|))) import Foreign.Marshal.Alloc import Foreign.Storable import GHC.Exts @@ -29,31 +29,35 @@ import Test.Tasty.HUnit as HU (Assertion, testCase, (@?=)) import Test.Tasty.QuickCheck as QC unsafeTests :: TestTree -unsafeTests = testGroup "Unit tests for unsafe C bindings" - [ roundtripSliceUnit "\NULabc\NUL" - , roundtripSliceUnit largeByteString - , roundtripByteBufferUnit largeByteString - , roundtripTimeSpec (TimeSpec 123 123) - , testMetadata - , testMetadataOrdering - , testMetadataOrderingProp - , testNow - , testCreateDestroyMetadata - , testCreateDestroyMetadataKeyVals - , testCreateDestroyDeadline - , testCreateDestroyChannelArgs - , testCreateDestroyClientCreds - , testCreateDestroyServerCreds - ] +unsafeTests = + testGroup + "Unit tests for unsafe C bindings" + [ roundtripSliceUnit "\NULabc\NUL" + , roundtripSliceUnit largeByteString + , roundtripByteBufferUnit largeByteString + , roundtripTimeSpec (TimeSpec 123 123) + , testMetadata + , testMetadataOrdering + , testMetadataOrderingProp + , testNow + , testCreateDestroyMetadata + , testCreateDestroyMetadataKeyVals + , testCreateDestroyDeadline + , testCreateDestroyChannelArgs + , testCreateDestroyClientCreds + , testCreateDestroyServerCreds + ] unsafeProperties :: TestTree -unsafeProperties = testGroup "QuickCheck properties for unsafe C bindings" - [ roundtripSliceQC - , roundtripByteBufferQC - , roundtripMetadataQC - , metadataIsList - , roundtripMetadataOrdering - ] +unsafeProperties = + testGroup + "QuickCheck properties for unsafe C bindings" + [ roundtripSliceQC + , roundtripByteBufferQC + , roundtripMetadataQC + , metadataIsList + , roundtripMetadataOrdering + ] instance Arbitrary B.ByteString where arbitrary = B.pack <$> arbitrary @@ -64,9 +68,10 @@ instance Arbitrary MetadataMap where let key = arbitrary `suchThat` B.notElem 0 ks0 <- listOf key duplicateKeys <- arbitrary - ks <- if duplicateKeys - then (ks0 <>) . concat . replicate 2 <$> listOf1 key - else pure ks0 + ks <- + if duplicateKeys + then (ks0 <>) . concat . replicate 2 <$> listOf1 key + else pure ks0 fromList . zip ks <$> vector (length ks) roundtripMetadataKeyVals :: MetadataMap -> IO MetadataMap @@ -78,19 +83,21 @@ roundtripMetadataKeyVals m = do roundtripMetadataQC :: TestTree roundtripMetadataQC = QC.testProperty "Metadata roundtrip" $ - \m -> QC.ioProperty $ do m' <- roundtripMetadataKeyVals m - return $ m === m' + \m -> QC.ioProperty $ do + m' <- roundtripMetadataKeyVals m + return $ m === m' metadataIsList :: TestTree metadataIsList = QC.testProperty "Metadata IsList instance" $ - \(md :: MetadataMap) -> md == (fromList $ toList md) + \(md :: MetadataMap) -> md == (fromList $ toList md) roundtripMetadataOrdering :: TestTree -roundtripMetadataOrdering = QC.testProperty "Metadata map ordering" $ - QC.ioProperty . checkMetadataOrdering +roundtripMetadataOrdering = + QC.testProperty "Metadata map ordering" $ + QC.ioProperty . checkMetadataOrdering largeByteString :: B.ByteString -largeByteString = B.pack $ take (32*1024*1024) $ cycle [97..99] +largeByteString = B.pack $ take (32 * 1024 * 1024) $ cycle [97 .. 99] roundtripSlice :: B.ByteString -> IO B.ByteString roundtripSlice bs = do @@ -101,8 +108,9 @@ roundtripSlice bs = do roundtripSliceQC :: TestTree roundtripSliceQC = QC.testProperty "Slice roundtrip: QuickCheck" $ - \bs -> QC.ioProperty $ do bs' <- roundtripSlice bs - return $ bs == bs' + \bs -> QC.ioProperty $ do + bs' <- roundtripSlice bs + return $ bs == bs' roundtripSliceUnit :: B.ByteString -> TestTree roundtripSliceUnit bs = testCase "ByteString slice roundtrip" $ do @@ -124,8 +132,9 @@ roundtripByteBuffer bs = do roundtripByteBufferQC :: TestTree roundtripByteBufferQC = QC.testProperty "ByteBuffer roundtrip: QuickCheck" $ - \bs -> QC.ioProperty $ do bs' <- roundtripByteBuffer bs - return $ bs == bs' + \bs -> QC.ioProperty $ do + bs' <- roundtripByteBuffer bs + return $ bs == bs' roundtripByteBufferUnit :: B.ByteString -> TestTree roundtripByteBufferUnit bs = testCase "ByteBuffer roundtrip" $ do @@ -178,15 +187,17 @@ testMetadataOrdering = testCase "Metadata map ordering (simple)" $ do MD.lookupLast "foo" rl @?= Just "bar" testMetadataOrderingProp :: TestTree -testMetadataOrderingProp = testCase "Metadata map ordering prop w/ trivial inputs" $ - mapM_ (checkMetadataOrdering . fromList) - [ [("foo", "bar"), ("fnord", "FNORD"), ("foo", "baz")] - , [("foo", "baz"), ("fnord", "FNORD"), ("foo", "bar")] - ] +testMetadataOrderingProp = + testCase "Metadata map ordering prop w/ trivial inputs" $ + mapM_ + (checkMetadataOrdering . fromList) + [ [("foo", "bar"), ("fnord", "FNORD"), ("foo", "baz")] + , [("foo", "baz"), ("fnord", "FNORD"), ("foo", "bar")] + ] checkMetadataOrdering :: MetadataMap -> Assertion checkMetadataOrdering md0 = do - let ikvps = toList md0 `zip` [0..] + let ikvps = toList md0 `zip` [0 ..] let ok md = unMap md @?= M.unionsWith (<>) [M.singleton k [v] | ((k, v), _i) <- ikvps] ok md0 md1 <- do @@ -227,21 +238,32 @@ testCreateDestroyDeadline = testCase "Create/destroy deadline" $ do grpc $ withDeadlineSeconds 10 $ const $ return () testCreateDestroyChannelArgs :: TestTree -testCreateDestroyChannelArgs = testCase "Create/destroy channel args" $ - grpc $ withChannelArgs [CompressionAlgArg GrpcCompressDeflate] $ - const $ return () +testCreateDestroyChannelArgs = + testCase "Create/destroy channel args" $ + grpc $ + withChannelArgs [CompressionAlgArg GrpcCompressDeflate] $ + const $ + return () testCreateDestroyClientCreds :: TestTree -testCreateDestroyClientCreds = testCase "Create/destroy client credentials" $ - grpc $ withChannelCredentials Nothing Nothing Nothing $ const $ return () +testCreateDestroyClientCreds = + testCase "Create/destroy client credentials" $ + grpc $ + withChannelCredentials Nothing Nothing Nothing $ + const $ + return () testCreateDestroyServerCreds :: TestTree -testCreateDestroyServerCreds = testCase "Create/destroy server credentials" $ - grpc $ withServerCredentials Nothing - "tests/ssl/testServerKey.pem" - "tests/ssl/testServerCert.pem" - SslDontRequestClientCertificate - $ const $ return () +testCreateDestroyServerCreds = + testCase "Create/destroy server credentials" + $ grpc + $ withServerCredentials + Nothing + "tests/ssl/testServerKey.pem" + "tests/ssl/testServerCert.pem" + SslDontRequestClientCertificate + $ const + $ return () assertCqEventComplete :: Event -> IO () assertCqEventComplete e = do diff --git a/examples/echo/echo-hs/Echo.hs b/examples/echo/echo-hs/Echo.hs index beb8abc5..4d915a22 100644 --- a/examples/echo/echo-hs/Echo.hs +++ b/examples/echo/echo-hs/Echo.hs @@ -11,23 +11,15 @@ -- | Generated by Haskell protocol buffer compiler. DO NOT EDIT! module Echo where -import qualified Prelude as Hs -import qualified Proto3.Suite.Class as HsProtobuf -import qualified Proto3.Suite.DotProto as HsProtobufAST -import qualified Proto3.Suite.JSONPB as HsJSONPB -import Proto3.Suite.JSONPB ((.=), (.:)) -import qualified Proto3.Suite.Types as HsProtobuf -import qualified Proto3.Wire as HsProtobuf -import qualified Proto3.Wire.Decode as HsProtobuf - (Parser, RawField) + +import Control.Applicative ((<$>), (<*>), (<|>)) import qualified Control.Applicative as Hs -import Control.Applicative ((<*>), (<|>), (<$>)) import qualified Control.DeepSeq as Hs import qualified Control.Monad as Hs import qualified Data.ByteString as Hs import qualified Data.Coerce as Hs import qualified Data.Int as Hs (Int16, Int32, Int64) -import qualified Data.List.NonEmpty as Hs (NonEmpty(..)) +import qualified Data.List.NonEmpty as Hs (NonEmpty (..)) import qualified Data.Map as Hs (Map, mapKeysMonotonic) import qualified Data.Proxy as Proxy import qualified Data.String as Hs (fromString) @@ -36,192 +28,289 @@ import qualified Data.Vector as Hs (Vector) import qualified Data.Word as Hs (Word16, Word32, Word64) import qualified GHC.Enum as Hs import qualified GHC.Generics as Hs -import qualified Google.Protobuf.Wrappers.Polymorphic as HsProtobuf - (Wrapped(..)) -import qualified Unsafe.Coerce as Hs -import Network.GRPC.HighLevel.Generated as HsGRPC +import qualified Google.Protobuf.Wrappers.Polymorphic as HsProtobuf ( + Wrapped (..), + ) import Network.GRPC.HighLevel.Client as HsGRPC +import Network.GRPC.HighLevel.Generated as HsGRPC import Network.GRPC.HighLevel.Server as HsGRPC hiding (serverLoop) -import Network.GRPC.HighLevel.Server.Unregistered as HsGRPC - (serverLoop) - -data Echo request response = Echo{echoDoEcho :: - request 'HsGRPC.Normal Echo.EchoRequest Echo.EchoResponse -> - Hs.IO (response 'HsGRPC.Normal Echo.EchoResponse)} - deriving Hs.Generic - +import Network.GRPC.HighLevel.Server.Unregistered as HsGRPC ( + serverLoop, + ) +import qualified Proto3.Suite.Class as HsProtobuf +import qualified Proto3.Suite.DotProto as HsProtobufAST +import Proto3.Suite.JSONPB ((.:), (.=)) +import qualified Proto3.Suite.JSONPB as HsJSONPB +import qualified Proto3.Suite.Types as HsProtobuf +import qualified Proto3.Wire as HsProtobuf +import qualified Proto3.Wire.Decode as HsProtobuf ( + Parser, + RawField, + ) +import qualified Unsafe.Coerce as Hs +import qualified Prelude as Hs + +data Echo request response = Echo + { echoDoEcho :: + request 'HsGRPC.Normal Echo.EchoRequest Echo.EchoResponse -> + Hs.IO (response 'HsGRPC.Normal Echo.EchoResponse) + } + deriving (Hs.Generic) + echoServer :: - Echo HsGRPC.ServerRequest HsGRPC.ServerResponse -> - HsGRPC.ServiceOptions -> Hs.IO () -echoServer Echo{echoDoEcho = echoDoEcho} - (ServiceOptions serverHost serverPort useCompression - userAgentPrefix userAgentSuffix initialMetadata sslConfig logger - serverMaxReceiveMessageLength serverMaxMetadataSize) - = (HsGRPC.serverLoop - HsGRPC.defaultOptions{HsGRPC.optNormalHandlers = - [(HsGRPC.UnaryHandler (HsGRPC.MethodName "/echo.Echo/DoEcho") - (HsGRPC.convertGeneratedServerHandler echoDoEcho))], - HsGRPC.optClientStreamHandlers = [], - HsGRPC.optServerStreamHandlers = [], - HsGRPC.optBiDiStreamHandlers = [], optServerHost = serverHost, - optServerPort = serverPort, optUseCompression = useCompression, - optUserAgentPrefix = userAgentPrefix, - optUserAgentSuffix = userAgentSuffix, - optInitialMetadata = initialMetadata, optSSLConfig = sslConfig, - optLogger = logger, - optMaxReceiveMessageLength = serverMaxReceiveMessageLength, - optMaxMetadataSize = serverMaxMetadataSize}) - + Echo HsGRPC.ServerRequest HsGRPC.ServerResponse -> + HsGRPC.ServiceOptions -> + Hs.IO () +echoServer + Echo{echoDoEcho = echoDoEcho} + ( ServiceOptions + serverHost + serverPort + useCompression + userAgentPrefix + userAgentSuffix + initialMetadata + sslConfig + logger + serverMaxReceiveMessageLength + serverMaxMetadataSize + ) = + ( HsGRPC.serverLoop + HsGRPC.defaultOptions + { HsGRPC.optNormalHandlers = + [ ( HsGRPC.UnaryHandler + (HsGRPC.MethodName "/echo.Echo/DoEcho") + (HsGRPC.convertGeneratedServerHandler echoDoEcho) + ) + ] + , HsGRPC.optClientStreamHandlers = [] + , HsGRPC.optServerStreamHandlers = [] + , HsGRPC.optBiDiStreamHandlers = [] + , optServerHost = serverHost + , optServerPort = serverPort + , optUseCompression = useCompression + , optUserAgentPrefix = userAgentPrefix + , optUserAgentSuffix = userAgentSuffix + , optInitialMetadata = initialMetadata + , optSSLConfig = sslConfig + , optLogger = logger + , optMaxReceiveMessageLength = serverMaxReceiveMessageLength + , optMaxMetadataSize = serverMaxMetadataSize + } + ) + echoClient :: - HsGRPC.Client -> - Hs.IO (Echo HsGRPC.ClientRequest HsGRPC.ClientResult) -echoClient client - = (Hs.pure Echo) <*> - ((Hs.pure (HsGRPC.clientRequest client)) <*> - (HsGRPC.clientRegisterMethod client - (HsGRPC.MethodName "/echo.Echo/DoEcho"))) - -newtype EchoRequest = EchoRequest{echoRequestMessage :: Hs.Text} - deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) - + HsGRPC.Client -> + Hs.IO (Echo HsGRPC.ClientRequest HsGRPC.ClientResult) +echoClient client = + (Hs.pure Echo) + <*> ( (Hs.pure (HsGRPC.clientRequest client)) + <*> ( HsGRPC.clientRegisterMethod + client + (HsGRPC.MethodName "/echo.Echo/DoEcho") + ) + ) + +newtype EchoRequest = EchoRequest {echoRequestMessage :: Hs.Text} + deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) + instance Hs.NFData EchoRequest - + instance HsProtobuf.Named EchoRequest where - nameOf _ = (Hs.fromString "EchoRequest") - + nameOf _ = (Hs.fromString "EchoRequest") + instance HsProtobuf.HasDefault EchoRequest - + instance HsProtobuf.Message EchoRequest where - encodeMessage _ - EchoRequest{echoRequestMessage = echoRequestMessage} - = (Hs.mconcat - [(HsProtobuf.encodeMessageField (HsProtobuf.FieldNumber 1) - (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) - (echoRequestMessage)))]) - decodeMessage _ - = (Hs.pure EchoRequest) <*> - (HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) - (HsProtobuf.at HsProtobuf.decodeMessageField - (HsProtobuf.FieldNumber 1))) - dotProto _ - = [(HsProtobufAST.DotProtoField (HsProtobuf.FieldNumber 1) - (HsProtobufAST.Prim HsProtobufAST.String) - (HsProtobufAST.Single "message") - [] - "")] - + encodeMessage + _ + EchoRequest{echoRequestMessage = echoRequestMessage} = + ( Hs.mconcat + [ ( HsProtobuf.encodeMessageField + (HsProtobuf.FieldNumber 1) + ( Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) + (echoRequestMessage) + ) + ) + ] + ) + decodeMessage _ = + (Hs.pure EchoRequest) + <*> ( HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) + ( HsProtobuf.at + HsProtobuf.decodeMessageField + (HsProtobuf.FieldNumber 1) + ) + ) + dotProto _ = + [ ( HsProtobufAST.DotProtoField + (HsProtobuf.FieldNumber 1) + (HsProtobufAST.Prim HsProtobufAST.String) + (HsProtobufAST.Single "message") + [] + "" + ) + ] + instance HsJSONPB.ToJSONPB EchoRequest where - toJSONPB (EchoRequest f1) - = (HsJSONPB.object - ["message" .= - (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1))]) - toEncodingPB (EchoRequest f1) - = (HsJSONPB.pairs - ["message" .= - (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1))]) - + toJSONPB (EchoRequest f1) = + ( HsJSONPB.object + [ "message" + .= (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1)) + ] + ) + toEncodingPB (EchoRequest f1) = + ( HsJSONPB.pairs + [ "message" + .= (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1)) + ] + ) + instance HsJSONPB.FromJSONPB EchoRequest where - parseJSONPB - = (HsJSONPB.withObject "EchoRequest" - (\ obj -> - (Hs.pure EchoRequest) <*> - (HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) - (obj .: "message")))) - + parseJSONPB = + ( HsJSONPB.withObject + "EchoRequest" + ( \obj -> + (Hs.pure EchoRequest) + <*> ( HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) + (obj .: "message") + ) + ) + ) + instance HsJSONPB.ToJSON EchoRequest where - toJSON = HsJSONPB.toAesonValue - toEncoding = HsJSONPB.toAesonEncoding - + toJSON = HsJSONPB.toAesonValue + toEncoding = HsJSONPB.toAesonEncoding + instance HsJSONPB.FromJSON EchoRequest where - parseJSON = HsJSONPB.parseJSONPB - + parseJSON = HsJSONPB.parseJSONPB + instance HsJSONPB.ToSchema EchoRequest where - declareNamedSchema _ - = do let declare_message = HsJSONPB.declareSchemaRef - echoRequestMessage <- declare_message Proxy.Proxy - let _ = Hs.pure EchoRequest <*> - (HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) - (HsJSONPB.asProxy declare_message)) - Hs.return - (HsJSONPB.NamedSchema{HsJSONPB._namedSchemaName = - Hs.Just "EchoRequest", - HsJSONPB._namedSchemaSchema = - Hs.mempty{HsJSONPB._schemaParamSchema = - Hs.mempty{HsJSONPB._paramSchemaType = - Hs.Just HsJSONPB.SwaggerObject}, - HsJSONPB._schemaProperties = - HsJSONPB.insOrdFromList - [("message", echoRequestMessage)]}}) - -newtype EchoResponse = EchoResponse{echoResponseMessage :: Hs.Text} - deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) - + declareNamedSchema _ = + do + let declare_message = HsJSONPB.declareSchemaRef + echoRequestMessage <- declare_message Proxy.Proxy + let _ = + Hs.pure EchoRequest + <*> ( HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) + (HsJSONPB.asProxy declare_message) + ) + Hs.return + ( HsJSONPB.NamedSchema + { HsJSONPB._namedSchemaName = + Hs.Just "EchoRequest" + , HsJSONPB._namedSchemaSchema = + Hs.mempty + { HsJSONPB._schemaParamSchema = + Hs.mempty + { HsJSONPB._paramSchemaType = + Hs.Just HsJSONPB.SwaggerObject + } + , HsJSONPB._schemaProperties = + HsJSONPB.insOrdFromList + [("message", echoRequestMessage)] + } + } + ) + +newtype EchoResponse = EchoResponse {echoResponseMessage :: Hs.Text} + deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) + instance Hs.NFData EchoResponse - + instance HsProtobuf.Named EchoResponse where - nameOf _ = (Hs.fromString "EchoResponse") - + nameOf _ = (Hs.fromString "EchoResponse") + instance HsProtobuf.HasDefault EchoResponse - + instance HsProtobuf.Message EchoResponse where - encodeMessage _ - EchoResponse{echoResponseMessage = echoResponseMessage} - = (Hs.mconcat - [(HsProtobuf.encodeMessageField (HsProtobuf.FieldNumber 1) - (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) - (echoResponseMessage)))]) - decodeMessage _ - = (Hs.pure EchoResponse) <*> - (HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) - (HsProtobuf.at HsProtobuf.decodeMessageField - (HsProtobuf.FieldNumber 1))) - dotProto _ - = [(HsProtobufAST.DotProtoField (HsProtobuf.FieldNumber 1) - (HsProtobufAST.Prim HsProtobufAST.String) - (HsProtobufAST.Single "message") - [] - "")] - + encodeMessage + _ + EchoResponse{echoResponseMessage = echoResponseMessage} = + ( Hs.mconcat + [ ( HsProtobuf.encodeMessageField + (HsProtobuf.FieldNumber 1) + ( Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) + (echoResponseMessage) + ) + ) + ] + ) + decodeMessage _ = + (Hs.pure EchoResponse) + <*> ( HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) + ( HsProtobuf.at + HsProtobuf.decodeMessageField + (HsProtobuf.FieldNumber 1) + ) + ) + dotProto _ = + [ ( HsProtobufAST.DotProtoField + (HsProtobuf.FieldNumber 1) + (HsProtobufAST.Prim HsProtobufAST.String) + (HsProtobufAST.Single "message") + [] + "" + ) + ] + instance HsJSONPB.ToJSONPB EchoResponse where - toJSONPB (EchoResponse f1) - = (HsJSONPB.object - ["message" .= - (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1))]) - toEncodingPB (EchoResponse f1) - = (HsJSONPB.pairs - ["message" .= - (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1))]) - + toJSONPB (EchoResponse f1) = + ( HsJSONPB.object + [ "message" + .= (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1)) + ] + ) + toEncodingPB (EchoResponse f1) = + ( HsJSONPB.pairs + [ "message" + .= (Hs.coerce @(Hs.Text) @(HsProtobuf.String Hs.Text) (f1)) + ] + ) + instance HsJSONPB.FromJSONPB EchoResponse where - parseJSONPB - = (HsJSONPB.withObject "EchoResponse" - (\ obj -> - (Hs.pure EchoResponse) <*> - (HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) - (obj .: "message")))) - + parseJSONPB = + ( HsJSONPB.withObject + "EchoResponse" + ( \obj -> + (Hs.pure EchoResponse) + <*> ( HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) + (obj .: "message") + ) + ) + ) + instance HsJSONPB.ToJSON EchoResponse where - toJSON = HsJSONPB.toAesonValue - toEncoding = HsJSONPB.toAesonEncoding - + toJSON = HsJSONPB.toAesonValue + toEncoding = HsJSONPB.toAesonEncoding + instance HsJSONPB.FromJSON EchoResponse where - parseJSON = HsJSONPB.parseJSONPB - -instance HsJSONPB.ToSchema EchoResponse where - declareNamedSchema _ - = do let declare_message = HsJSONPB.declareSchemaRef - echoResponseMessage <- declare_message Proxy.Proxy - let _ = Hs.pure EchoResponse <*> - (HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) - (HsJSONPB.asProxy declare_message)) - Hs.return - (HsJSONPB.NamedSchema{HsJSONPB._namedSchemaName = - Hs.Just "EchoResponse", - HsJSONPB._namedSchemaSchema = - Hs.mempty{HsJSONPB._schemaParamSchema = - Hs.mempty{HsJSONPB._paramSchemaType = - Hs.Just HsJSONPB.SwaggerObject}, - HsJSONPB._schemaProperties = - HsJSONPB.insOrdFromList - [("message", echoResponseMessage)]}}) + parseJSON = HsJSONPB.parseJSONPB +instance HsJSONPB.ToSchema EchoResponse where + declareNamedSchema _ = + do + let declare_message = HsJSONPB.declareSchemaRef + echoResponseMessage <- declare_message Proxy.Proxy + let _ = + Hs.pure EchoResponse + <*> ( HsProtobuf.coerceOver @(HsProtobuf.String Hs.Text) @(Hs.Text) + (HsJSONPB.asProxy declare_message) + ) + Hs.return + ( HsJSONPB.NamedSchema + { HsJSONPB._namedSchemaName = + Hs.Just "EchoResponse" + , HsJSONPB._namedSchemaSchema = + Hs.mempty + { HsJSONPB._schemaParamSchema = + Hs.mempty + { HsJSONPB._paramSchemaType = + Hs.Just HsJSONPB.SwaggerObject + } + , HsJSONPB._schemaProperties = + HsJSONPB.insOrdFromList + [("message", echoResponseMessage)] + } + } + ) diff --git a/examples/echo/echo-hs/EchoClient.hs b/examples/echo/echo-hs/EchoClient.hs index b12bba4c..ec856579 100644 --- a/examples/echo/echo-hs/EchoClient.hs +++ b/examples/echo/echo-hs/EchoClient.hs @@ -1,45 +1,49 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeOperators #-} -import Control.Monad -import Data.ByteString (ByteString) -import Data.Maybe (fromMaybe) -import qualified Data.Text.Lazy as TL -import Echo -import Network.GRPC.HighLevel.Client -import Network.GRPC.LowLevel -import Network.GRPC.LowLevel.Call (Endpoint(..)) -import Options.Generic -import Prelude hiding (FilePath) +import Control.Monad +import Data.ByteString (ByteString) +import Data.Maybe (fromMaybe) +import qualified Data.Text.Lazy as TL +import Echo +import Network.GRPC.HighLevel.Client +import Network.GRPC.LowLevel +import Network.GRPC.LowLevel.Call (Endpoint (..)) +import Options.Generic +import Prelude hiding (FilePath) data Args = Args - { endpoint :: Maybe ByteString "grpc endpoint (default \"localhost:50051\")" - , payload :: Maybe TL.Text "string to echo (default \"hullo!\")" - } deriving (Generic, Show) + { endpoint :: Maybe ByteString "grpc endpoint (default \"localhost:50051\")" + , payload :: Maybe TL.Text "string to echo (default \"hullo!\")" + } + deriving (Generic, Show) instance ParseRecord Args main :: IO () main = do Args{..} <- getRecord "Runs the echo client" let - pay = fromMaybe "hullo!" . unHelpful $ payload - rqt = EchoRequest pay + pay = fromMaybe "hullo!" . unHelpful $ payload + rqt = EchoRequest pay expected = EchoResponse pay - cfg = ClientConfig - (Endpoint . fromMaybe "localhost:50051" . unHelpful $ endpoint) - [] Nothing Nothing + cfg = + ClientConfig + (Endpoint . fromMaybe "localhost:50051" . unHelpful $ endpoint) + [] + Nothing + Nothing withGRPC $ \g -> withClient g cfg $ \c -> do Echo{..} <- echoClient c echoDoEcho (ClientNormalRequest rqt 5 mempty) >>= \case ClientNormalResponse rsp _ _ StatusOk _ - | rsp == expected -> return () - | otherwise -> fail $ "Got unexpected response: '" ++ show rsp ++ "', expected: '" ++ show expected ++ "'" + | rsp == expected -> return () + | otherwise -> fail $ "Got unexpected response: '" ++ show rsp ++ "', expected: '" ++ show expected ++ "'" ClientNormalResponse _ _ _ st _ -> fail $ "Got unexpected status " ++ show st ++ " from call, expecting StatusOk" - ClientErrorResponse e -> fail $ "Got client error: " ++ show e + ClientErrorResponse e -> fail $ "Got client error: " ++ show e putStrLn $ "echo-client success: sent " ++ show pay ++ ", got " ++ show pay diff --git a/examples/echo/echo-hs/EchoServer.hs b/examples/echo/echo-hs/EchoServer.hs index 653d727c..c91b4c7b 100644 --- a/examples/echo/echo-hs/EchoServer.hs +++ b/examples/echo/echo-hs/EchoServer.hs @@ -1,41 +1,48 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} -{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TypeOperators #-} -import Data.ByteString (ByteString) -import Data.Maybe (fromMaybe) -import Network.GRPC.HighLevel.Generated (GRPCMethodType (..), - Host (..), Port (..), - ServerRequest (..), - ServerResponse (..), - StatusCode (..), - defaultServiceOptions, - serverHost, serverPort) -import Options.Generic +import Data.ByteString (ByteString) +import Data.Maybe (fromMaybe) +import Network.GRPC.HighLevel.Generated ( + GRPCMethodType (..), + Host (..), + Port (..), + ServerRequest (..), + ServerResponse (..), + StatusCode (..), + defaultServiceOptions, + serverHost, + serverPort, + ) +import Options.Generic -import Echo +import Echo data Args = Args { bind :: Maybe ByteString "grpc endpoint hostname (default \"localhost\")" - , port :: Maybe Int "grpc endpoint port (default 50051)" - } deriving (Generic, Show) + , port :: Maybe Int "grpc endpoint port (default 50051)" + } + deriving (Generic, Show) instance ParseRecord Args -doEcho :: ServerRequest 'Normal EchoRequest EchoResponse - -> IO (ServerResponse 'Normal EchoResponse) +doEcho :: + ServerRequest 'Normal EchoRequest EchoResponse -> + IO (ServerResponse 'Normal EchoResponse) doEcho (ServerNormalRequest _meta (EchoRequest pay)) = do return (ServerNormalResponse (EchoResponse pay) mempty StatusOk "") main :: IO () main = do Args{..} <- getRecord "Runs the echo service" - let opts = defaultServiceOptions - { serverHost = Host . fromMaybe "localhost" . unHelpful $ bind - , serverPort = Port . fromMaybe 50051 . unHelpful $ port - } - echoServer Echo{ echoDoEcho = doEcho } opts + let opts = + defaultServiceOptions + { serverHost = Host . fromMaybe "localhost" . unHelpful $ bind + , serverPort = Port . fromMaybe 50051 . unHelpful $ port + } + echoServer Echo{echoDoEcho = doEcho} opts diff --git a/examples/hellos/hellos-client/Main.hs b/examples/hellos/hellos-client/Main.hs index 4d0968d8..8e409343 100644 --- a/examples/hellos/hellos-client/Main.hs +++ b/examples/hellos/hellos-client/Main.hs @@ -1,34 +1,34 @@ -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} -{-# LANGUAGE ViewPatterns #-} -{-# OPTIONS_GHC -fno-warn-unused-binds #-} +{-# LANGUAGE ViewPatterns #-} +{-# OPTIONS_GHC -fno-warn-unused-binds #-} -import Control.Concurrent.Async -import Control.Monad -import qualified Data.ByteString.Lazy as BL -import Data.Function -import qualified Data.Text as T -import Data.Word -import GHC.Generics (Generic) -import Network.GRPC.LowLevel -import Proto3.Suite.Class +import Control.Concurrent.Async +import Control.Monad +import qualified Data.ByteString.Lazy as BL +import Data.Function +import qualified Data.Text as T +import Data.Word +import GHC.Generics (Generic) +import Network.GRPC.LowLevel +import Proto3.Suite.Class helloSS, helloCS, helloBi :: MethodName helloSS = MethodName "/hellos.Hellos/HelloSS" helloCS = MethodName "/hellos.Hellos/HelloCS" helloBi = MethodName "/hellos.Hellos/HelloBi" -data SSRqt = SSRqt { ssName :: T.Text, ssNumReplies :: Word32 } deriving (Show, Eq, Ord, Generic) +data SSRqt = SSRqt {ssName :: T.Text, ssNumReplies :: Word32} deriving (Show, Eq, Ord, Generic) instance Message SSRqt -data SSRpy = SSRpy { ssGreeting :: T.Text } deriving (Show, Eq, Ord, Generic) +data SSRpy = SSRpy {ssGreeting :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message SSRpy -data CSRqt = CSRqt { csMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +data CSRqt = CSRqt {csMessage :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message CSRqt -data CSRpy = CSRpy { csNumRequests :: Word32 } deriving (Show, Eq, Ord, Generic) +data CSRpy = CSRpy {csNumRequests :: Word32} deriving (Show, Eq, Ord, Generic) instance Message CSRpy -data BiRqtRpy = BiRqtRpy { biMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +data BiRqtRpy = BiRqtRpy {biMessage :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message BiRqtRpy expect :: (Eq a, MonadFail m, Show a) => String -> a -> a -> m () @@ -39,41 +39,43 @@ expect ctx ex got doHelloSS :: Client -> Int -> IO () doHelloSS c n = do rm <- clientRegisterMethodServerStreaming c helloSS - let pay = SSRqt "server streaming mode" (fromIntegral n) - enc = BL.toStrict . toLazyByteString $ pay + let pay = SSRqt "server streaming mode" (fromIntegral n) + enc = BL.toStrict . toLazyByteString $ pay err desc e = fail $ "doHelloSS: " ++ desc ++ " error: " ++ show e eea <- clientReader c rm n enc mempty $ \_cc _md recv -> do - n' <- flip fix (0::Int) $ \go i -> recv >>= \case - Left e -> err "recv" e - Right Nothing -> return i - Right (Just bs) -> case fromByteString bs of - Left e -> err "decoding" e - Right r -> expect "doHelloSS/rpy" expay (ssGreeting r) >> go (i+1) + n' <- flip fix (0 :: Int) $ \go i -> + recv >>= \case + Left e -> err "recv" e + Right Nothing -> return i + Right (Just bs) -> case fromByteString bs of + Left e -> err "decoding" e + Right r -> expect "doHelloSS/rpy" expay (ssGreeting r) >> go (i + 1) expect "doHelloSS/cnt" n n' case eea of - Left e -> err "clientReader" e + Left e -> err "clientReader" e Right (_, st, _) | st /= StatusOk -> fail "clientReader: non-OK status" - | otherwise -> putStrLn "doHelloSS: RPC successful" + | otherwise -> putStrLn "doHelloSS: RPC successful" where - expay = "Hello there, server streaming mode!" + expay = "Hello there, server streaming mode!" doHelloCS :: Client -> Int -> IO () doHelloCS c n = do - rm <- clientRegisterMethodClientStreaming c helloCS + rm <- clientRegisterMethodClientStreaming c helloCS let pay = CSRqt "client streaming payload" enc = BL.toStrict . toLazyByteString $ pay eea <- clientWriter c rm n mempty $ \send -> - replicateM_ n $ send enc >>= \case - Left e -> fail $ "doHelloCS: send error: " ++ show e - Right{} -> return () + replicateM_ n $ + send enc >>= \case + Left e -> fail $ "doHelloCS: send error: " ++ show e + Right{} -> return () case eea of - Left e -> fail $ "clientWriter error: " ++ show e + Left e -> fail $ "clientWriter error: " ++ show e Right (Nothing, _, _, _, _) -> fail "clientWriter error: no reply payload" Right (Just bs, _init, _trail, st, _dtls) | st /= StatusOk -> fail "clientWriter: non-OK status" | otherwise -> case fromByteString bs of - Left e -> fail $ "Decoding error: " ++ show e + Left e -> fail $ "Decoding error: " ++ show e Right dec -> do expect "doHelloCS/cnt" (fromIntegral n) (csNumRequests dec) putStrLn "doHelloCS: RPC successful" @@ -81,29 +83,31 @@ doHelloCS c n = do doHelloBi :: Client -> Int -> IO () doHelloBi c n = do rm <- clientRegisterMethodBiDiStreaming c helloBi - let pay = BiRqtRpy "bidi payload" - enc = BL.toStrict . toLazyByteString $ pay + let pay = BiRqtRpy "bidi payload" + enc = BL.toStrict . toLazyByteString $ pay err desc e = fail $ "doHelloBi: " ++ desc ++ " error: " ++ show e eea <- clientRW c rm n mempty $ \_cc _getMD recv send writesDone -> do -- perform n writes on a worker thread thd <- async $ do - replicateM_ n $ send enc >>= \case - Left e -> err "send" e - _ -> return () + replicateM_ n $ + send enc >>= \case + Left e -> err "send" e + _ -> return () writesDone >>= \case Left e -> err "writesDone" e - _ -> return () + _ -> return () -- perform reads on this thread until the stream is terminated -- emd <- getMD; putStrLn ("getMD result: " ++ show emd) - fix $ \go -> recv >>= \case - Left e -> err "recv" e - Right Nothing -> return () - Right (Just bs) -> case fromByteString bs of - Left e -> err "decoding" e - Right r -> when (r /= pay) (fail "Reply payload mismatch") >> go + fix $ \go -> + recv >>= \case + Left e -> err "recv" e + Right Nothing -> return () + Right (Just bs) -> case fromByteString bs of + Left e -> err "decoding" e + Right r -> when (r /= pay) (fail "Reply payload mismatch") >> go wait thd case eea of - Left e -> err "clientRW'" e + Left e -> err "clientRW'" e Right (_, st, _) -> do when (st /= StatusOk) $ fail $ "clientRW: non-OK status: " ++ show st putStrLn "doHelloBi: RPC successful" diff --git a/examples/hellos/hellos-server/Main.hs b/examples/hellos/hellos-server/Main.hs index 45129384..2f80f92e 100644 --- a/examples/hellos/hellos-server/Main.hs +++ b/examples/hellos/hellos-server/Main.hs @@ -1,35 +1,35 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE DeriveGeneric #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} {-# OPTIONS_GHC -fno-warn-missing-signatures #-} -{-# OPTIONS_GHC -fno-warn-unused-binds #-} +{-# OPTIONS_GHC -fno-warn-unused-binds #-} -import Control.Monad -import Data.Function (fix) -import qualified Data.Text as T -import Data.Word -import GHC.Generics (Generic) -import Network.GRPC.HighLevel.Server +import Control.Monad +import Data.Function (fix) +import qualified Data.Text as T +import Data.Word +import GHC.Generics (Generic) +import Network.GRPC.HighLevel.Server import qualified Network.GRPC.HighLevel.Server.Unregistered as U -import Network.GRPC.LowLevel -import Proto3.Suite.Class +import Network.GRPC.LowLevel +import Proto3.Suite.Class serverMeta :: MetadataMap serverMeta = [("test_meta", "test_meta_value")] -data SSRqt = SSRqt { ssName :: T.Text, ssNumReplies :: Word32 } deriving (Show, Eq, Ord, Generic) +data SSRqt = SSRqt {ssName :: T.Text, ssNumReplies :: Word32} deriving (Show, Eq, Ord, Generic) instance Message SSRqt -data SSRpy = SSRpy { ssGreeting :: T.Text } deriving (Show, Eq, Ord, Generic) +data SSRpy = SSRpy {ssGreeting :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message SSRpy -data CSRqt = CSRqt { csMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +data CSRqt = CSRqt {csMessage :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message CSRqt -data CSRpy = CSRpy { csNumRequests :: Word32 } deriving (Show, Eq, Ord, Generic) +data CSRpy = CSRpy {csNumRequests :: Word32} deriving (Show, Eq, Ord, Generic) instance Message CSRpy -data BiRqtRpy = BiRqtRpy { biMessage :: T.Text } deriving (Show, Eq, Ord, Generic) +data BiRqtRpy = BiRqtRpy {biMessage :: T.Text} deriving (Show, Eq, Ord, Generic) instance Message BiRqtRpy expect :: (Eq a, MonadFail m, Show a) => String -> a -> a -> m () @@ -43,37 +43,38 @@ helloSS = ServerStreamHandler "/hellos.Hellos/HelloSS" $ \sc send -> do replicateM_ (fromIntegral ssNumReplies) $ do eea <- send $ SSRpy $ "Hello there, " <> ssName <> "!" case eea of - Left e -> fail $ "helloSS error: " ++ show e + Left e -> fail $ "helloSS error: " ++ show e Right{} -> return () return (serverMeta, StatusOk, StatusDetails "helloSS response details") helloCS :: Handler 'ClientStreaming helloCS = ClientStreamHandler "/hellos.Hellos/HelloCS" $ \_ recv -> flip fix 0 $ \go n -> recv >>= \case - Left e -> fail $ "helloCS error: " ++ show e - Right Nothing -> return (Just (CSRpy n), mempty, StatusOk, StatusDetails "helloCS details") + Left e -> fail $ "helloCS error: " ++ show e + Right Nothing -> return (Just (CSRpy n), mempty, StatusOk, StatusDetails "helloCS details") Right (Just rqt) -> do expect "helloCS" "client streaming payload" (csMessage rqt) - go (n+1) + go (n + 1) helloBi :: Handler 'BiDiStreaming helloBi = BiDiStreamHandler "/hellos.Hellos/HelloBi" $ \_ recv send -> fix $ \go -> recv >>= \case - Left e -> fail $ "helloBi recv error: " ++ show e - Right Nothing -> return (mempty, StatusOk, StatusDetails "helloBi details") + Left e -> fail $ "helloBi recv error: " ++ show e + Right Nothing -> return (mempty, StatusOk, StatusDetails "helloBi details") Right (Just rqt) -> do expect "helloBi" "bidi payload" (biMessage rqt) send rqt >>= \case Left e -> fail $ "helloBi send error: " ++ show e - _ -> go + _ -> go highlevelMainUnregistered :: IO () highlevelMainUnregistered = - U.serverLoop defaultOptions{ - optServerStreamHandlers = [helloSS] - , optClientStreamHandlers = [helloCS] - , optBiDiStreamHandlers = [helloBi] - } + U.serverLoop + defaultOptions + { optServerStreamHandlers = [helloSS] + , optClientStreamHandlers = [helloCS] + , optBiDiStreamHandlers = [helloBi] + } main :: IO () main = highlevelMainUnregistered diff --git a/examples/tutorial/Arithmetic.hs b/examples/tutorial/Arithmetic.hs index d8b1f349..72f052ec 100644 --- a/examples/tutorial/Arithmetic.hs +++ b/examples/tutorial/Arithmetic.hs @@ -11,23 +11,15 @@ -- | Generated by Haskell protocol buffer compiler. DO NOT EDIT! module Arithmetic where -import qualified Prelude as Hs -import qualified Proto3.Suite.Class as HsProtobuf -import qualified Proto3.Suite.DotProto as HsProtobufAST -import qualified Proto3.Suite.JSONPB as HsJSONPB -import Proto3.Suite.JSONPB ((.=), (.:)) -import qualified Proto3.Suite.Types as HsProtobuf -import qualified Proto3.Wire as HsProtobuf -import qualified Proto3.Wire.Decode as HsProtobuf - (Parser, RawField) + +import Control.Applicative ((<$>), (<*>), (<|>)) import qualified Control.Applicative as Hs -import Control.Applicative ((<*>), (<|>), (<$>)) import qualified Control.DeepSeq as Hs import qualified Control.Monad as Hs import qualified Data.ByteString as Hs import qualified Data.Coerce as Hs import qualified Data.Int as Hs (Int16, Int32, Int64) -import qualified Data.List.NonEmpty as Hs (NonEmpty(..)) +import qualified Data.List.NonEmpty as Hs (NonEmpty (..)) import qualified Data.Map as Hs (Map, mapKeysMonotonic) import qualified Data.Proxy as Proxy import qualified Data.String as Hs (fromString) @@ -36,197 +28,288 @@ import qualified Data.Vector as Hs (Vector) import qualified Data.Word as Hs (Word16, Word32, Word64) import qualified GHC.Enum as Hs import qualified GHC.Generics as Hs -import qualified Google.Protobuf.Wrappers.Polymorphic as HsProtobuf - (Wrapped(..)) -import qualified Unsafe.Coerce as Hs -import Network.GRPC.HighLevel.Generated as HsGRPC +import qualified Google.Protobuf.Wrappers.Polymorphic as HsProtobuf ( + Wrapped (..), + ) import Network.GRPC.HighLevel.Client as HsGRPC +import Network.GRPC.HighLevel.Generated as HsGRPC import Network.GRPC.HighLevel.Server as HsGRPC hiding (serverLoop) -import Network.GRPC.HighLevel.Server.Unregistered as HsGRPC - (serverLoop) - -data Arithmetic request response = Arithmetic{arithmeticAdd :: - request 'HsGRPC.Normal Arithmetic.TwoInts - Arithmetic.OneInt - -> - Hs.IO (response 'HsGRPC.Normal Arithmetic.OneInt), - arithmeticRunningSum :: - request 'HsGRPC.ClientStreaming Arithmetic.OneInt - Arithmetic.OneInt - -> - Hs.IO - (response 'HsGRPC.ClientStreaming - Arithmetic.OneInt)} - deriving Hs.Generic - +import Network.GRPC.HighLevel.Server.Unregistered as HsGRPC ( + serverLoop, + ) +import qualified Proto3.Suite.Class as HsProtobuf +import qualified Proto3.Suite.DotProto as HsProtobufAST +import Proto3.Suite.JSONPB ((.:), (.=)) +import qualified Proto3.Suite.JSONPB as HsJSONPB +import qualified Proto3.Suite.Types as HsProtobuf +import qualified Proto3.Wire as HsProtobuf +import qualified Proto3.Wire.Decode as HsProtobuf ( + Parser, + RawField, + ) +import qualified Unsafe.Coerce as Hs +import qualified Prelude as Hs + +data Arithmetic request response = Arithmetic + { arithmeticAdd :: + request + 'HsGRPC.Normal + Arithmetic.TwoInts + Arithmetic.OneInt -> + Hs.IO (response 'HsGRPC.Normal Arithmetic.OneInt) + , arithmeticRunningSum :: + request + 'HsGRPC.ClientStreaming + Arithmetic.OneInt + Arithmetic.OneInt -> + Hs.IO + ( response + 'HsGRPC.ClientStreaming + Arithmetic.OneInt + ) + } + deriving (Hs.Generic) + arithmeticServer :: - Arithmetic HsGRPC.ServerRequest HsGRPC.ServerResponse -> - HsGRPC.ServiceOptions -> Hs.IO () + Arithmetic HsGRPC.ServerRequest HsGRPC.ServerResponse -> + HsGRPC.ServiceOptions -> + Hs.IO () arithmeticServer - Arithmetic{arithmeticAdd = arithmeticAdd, - arithmeticRunningSum = arithmeticRunningSum} - (ServiceOptions serverHost serverPort useCompression - userAgentPrefix userAgentSuffix initialMetadata sslConfig logger - serverMaxReceiveMessageLength serverMaxMetadataSize) - = (HsGRPC.serverLoop - HsGRPC.defaultOptions{HsGRPC.optNormalHandlers = - [(HsGRPC.UnaryHandler - (HsGRPC.MethodName "/arithmetic.Arithmetic/Add") - (HsGRPC.convertGeneratedServerHandler arithmeticAdd))], - HsGRPC.optClientStreamHandlers = - [(HsGRPC.ClientStreamHandler - (HsGRPC.MethodName "/arithmetic.Arithmetic/RunningSum") - (HsGRPC.convertGeneratedServerReaderHandler - arithmeticRunningSum))], - HsGRPC.optServerStreamHandlers = [], - HsGRPC.optBiDiStreamHandlers = [], optServerHost = serverHost, - optServerPort = serverPort, optUseCompression = useCompression, - optUserAgentPrefix = userAgentPrefix, - optUserAgentSuffix = userAgentSuffix, - optInitialMetadata = initialMetadata, optSSLConfig = sslConfig, - optLogger = logger, - optMaxReceiveMessageLength = serverMaxReceiveMessageLength, - optMaxMetadataSize = serverMaxMetadataSize}) - + Arithmetic + { arithmeticAdd = arithmeticAdd + , arithmeticRunningSum = arithmeticRunningSum + } + ( ServiceOptions + serverHost + serverPort + useCompression + userAgentPrefix + userAgentSuffix + initialMetadata + sslConfig + logger + serverMaxReceiveMessageLength + serverMaxMetadataSize + ) = + ( HsGRPC.serverLoop + HsGRPC.defaultOptions + { HsGRPC.optNormalHandlers = + [ ( HsGRPC.UnaryHandler + (HsGRPC.MethodName "/arithmetic.Arithmetic/Add") + (HsGRPC.convertGeneratedServerHandler arithmeticAdd) + ) + ] + , HsGRPC.optClientStreamHandlers = + [ ( HsGRPC.ClientStreamHandler + (HsGRPC.MethodName "/arithmetic.Arithmetic/RunningSum") + ( HsGRPC.convertGeneratedServerReaderHandler + arithmeticRunningSum + ) + ) + ] + , HsGRPC.optServerStreamHandlers = [] + , HsGRPC.optBiDiStreamHandlers = [] + , optServerHost = serverHost + , optServerPort = serverPort + , optUseCompression = useCompression + , optUserAgentPrefix = userAgentPrefix + , optUserAgentSuffix = userAgentSuffix + , optInitialMetadata = initialMetadata + , optSSLConfig = sslConfig + , optLogger = logger + , optMaxReceiveMessageLength = serverMaxReceiveMessageLength + , optMaxMetadataSize = serverMaxMetadataSize + } + ) + arithmeticClient :: - HsGRPC.Client -> - Hs.IO (Arithmetic HsGRPC.ClientRequest HsGRPC.ClientResult) -arithmeticClient client - = (Hs.pure Arithmetic) <*> - ((Hs.pure (HsGRPC.clientRequest client)) <*> - (HsGRPC.clientRegisterMethod client - (HsGRPC.MethodName "/arithmetic.Arithmetic/Add"))) - <*> - ((Hs.pure (HsGRPC.clientRequest client)) <*> - (HsGRPC.clientRegisterMethod client - (HsGRPC.MethodName "/arithmetic.Arithmetic/RunningSum"))) - -data TwoInts = TwoInts{twoIntsX :: Hs.Int32, twoIntsY :: Hs.Int32} - deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) - + HsGRPC.Client -> + Hs.IO (Arithmetic HsGRPC.ClientRequest HsGRPC.ClientResult) +arithmeticClient client = + (Hs.pure Arithmetic) + <*> ( (Hs.pure (HsGRPC.clientRequest client)) + <*> ( HsGRPC.clientRegisterMethod + client + (HsGRPC.MethodName "/arithmetic.Arithmetic/Add") + ) + ) + <*> ( (Hs.pure (HsGRPC.clientRequest client)) + <*> ( HsGRPC.clientRegisterMethod + client + (HsGRPC.MethodName "/arithmetic.Arithmetic/RunningSum") + ) + ) + +data TwoInts = TwoInts {twoIntsX :: Hs.Int32, twoIntsY :: Hs.Int32} + deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) + instance Hs.NFData TwoInts - + instance HsProtobuf.Named TwoInts where - nameOf _ = (Hs.fromString "TwoInts") - + nameOf _ = (Hs.fromString "TwoInts") + instance HsProtobuf.HasDefault TwoInts - + instance HsProtobuf.Message TwoInts where - encodeMessage _ TwoInts{twoIntsX = twoIntsX, twoIntsY = twoIntsY} - = (Hs.mconcat - [(HsProtobuf.encodeMessageField (HsProtobuf.FieldNumber 1) - twoIntsX), - (HsProtobuf.encodeMessageField (HsProtobuf.FieldNumber 2) - twoIntsY)]) - decodeMessage _ - = (Hs.pure TwoInts) <*> - (HsProtobuf.at HsProtobuf.decodeMessageField - (HsProtobuf.FieldNumber 1)) - <*> - (HsProtobuf.at HsProtobuf.decodeMessageField - (HsProtobuf.FieldNumber 2)) - dotProto _ - = [(HsProtobufAST.DotProtoField (HsProtobuf.FieldNumber 1) - (HsProtobufAST.Prim HsProtobufAST.Int32) - (HsProtobufAST.Single "x") - [] - ""), - (HsProtobufAST.DotProtoField (HsProtobuf.FieldNumber 2) - (HsProtobufAST.Prim HsProtobufAST.Int32) - (HsProtobufAST.Single "y") - [] - "")] - + encodeMessage _ TwoInts{twoIntsX = twoIntsX, twoIntsY = twoIntsY} = + ( Hs.mconcat + [ ( HsProtobuf.encodeMessageField + (HsProtobuf.FieldNumber 1) + twoIntsX + ) + , ( HsProtobuf.encodeMessageField + (HsProtobuf.FieldNumber 2) + twoIntsY + ) + ] + ) + decodeMessage _ = + (Hs.pure TwoInts) + <*> ( HsProtobuf.at + HsProtobuf.decodeMessageField + (HsProtobuf.FieldNumber 1) + ) + <*> ( HsProtobuf.at + HsProtobuf.decodeMessageField + (HsProtobuf.FieldNumber 2) + ) + dotProto _ = + [ ( HsProtobufAST.DotProtoField + (HsProtobuf.FieldNumber 1) + (HsProtobufAST.Prim HsProtobufAST.Int32) + (HsProtobufAST.Single "x") + [] + "" + ) + , ( HsProtobufAST.DotProtoField + (HsProtobuf.FieldNumber 2) + (HsProtobufAST.Prim HsProtobufAST.Int32) + (HsProtobufAST.Single "y") + [] + "" + ) + ] + instance HsJSONPB.ToJSONPB TwoInts where - toJSONPB (TwoInts f1 f2) = (HsJSONPB.object ["x" .= f1, "y" .= f2]) - toEncodingPB (TwoInts f1 f2) - = (HsJSONPB.pairs ["x" .= f1, "y" .= f2]) - + toJSONPB (TwoInts f1 f2) = (HsJSONPB.object ["x" .= f1, "y" .= f2]) + toEncodingPB (TwoInts f1 f2) = + (HsJSONPB.pairs ["x" .= f1, "y" .= f2]) + instance HsJSONPB.FromJSONPB TwoInts where - parseJSONPB - = (HsJSONPB.withObject "TwoInts" - (\ obj -> (Hs.pure TwoInts) <*> obj .: "x" <*> obj .: "y")) - + parseJSONPB = + ( HsJSONPB.withObject + "TwoInts" + (\obj -> (Hs.pure TwoInts) <*> obj .: "x" <*> obj .: "y") + ) + instance HsJSONPB.ToJSON TwoInts where - toJSON = HsJSONPB.toAesonValue - toEncoding = HsJSONPB.toAesonEncoding - + toJSON = HsJSONPB.toAesonValue + toEncoding = HsJSONPB.toAesonEncoding + instance HsJSONPB.FromJSON TwoInts where - parseJSON = HsJSONPB.parseJSONPB - + parseJSON = HsJSONPB.parseJSONPB + instance HsJSONPB.ToSchema TwoInts where - declareNamedSchema _ - = do let declare_x = HsJSONPB.declareSchemaRef - twoIntsX <- declare_x Proxy.Proxy - let declare_y = HsJSONPB.declareSchemaRef - twoIntsY <- declare_y Proxy.Proxy - let _ = Hs.pure TwoInts <*> HsJSONPB.asProxy declare_x <*> - HsJSONPB.asProxy declare_y - Hs.return - (HsJSONPB.NamedSchema{HsJSONPB._namedSchemaName = - Hs.Just "TwoInts", - HsJSONPB._namedSchemaSchema = - Hs.mempty{HsJSONPB._schemaParamSchema = - Hs.mempty{HsJSONPB._paramSchemaType = - Hs.Just HsJSONPB.SwaggerObject}, - HsJSONPB._schemaProperties = - HsJSONPB.insOrdFromList - [("x", twoIntsX), ("y", twoIntsY)]}}) - -newtype OneInt = OneInt{oneIntResult :: Hs.Int32} - deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) - + declareNamedSchema _ = + do + let declare_x = HsJSONPB.declareSchemaRef + twoIntsX <- declare_x Proxy.Proxy + let declare_y = HsJSONPB.declareSchemaRef + twoIntsY <- declare_y Proxy.Proxy + let _ = + Hs.pure TwoInts + <*> HsJSONPB.asProxy declare_x + <*> HsJSONPB.asProxy declare_y + Hs.return + ( HsJSONPB.NamedSchema + { HsJSONPB._namedSchemaName = + Hs.Just "TwoInts" + , HsJSONPB._namedSchemaSchema = + Hs.mempty + { HsJSONPB._schemaParamSchema = + Hs.mempty + { HsJSONPB._paramSchemaType = + Hs.Just HsJSONPB.SwaggerObject + } + , HsJSONPB._schemaProperties = + HsJSONPB.insOrdFromList + [("x", twoIntsX), ("y", twoIntsY)] + } + } + ) + +newtype OneInt = OneInt {oneIntResult :: Hs.Int32} + deriving (Hs.Show, Hs.Eq, Hs.Ord, Hs.Generic) + instance Hs.NFData OneInt - + instance HsProtobuf.Named OneInt where - nameOf _ = (Hs.fromString "OneInt") - + nameOf _ = (Hs.fromString "OneInt") + instance HsProtobuf.HasDefault OneInt - + instance HsProtobuf.Message OneInt where - encodeMessage _ OneInt{oneIntResult = oneIntResult} - = (Hs.mconcat - [(HsProtobuf.encodeMessageField (HsProtobuf.FieldNumber 1) - oneIntResult)]) - decodeMessage _ - = (Hs.pure OneInt) <*> - (HsProtobuf.at HsProtobuf.decodeMessageField - (HsProtobuf.FieldNumber 1)) - dotProto _ - = [(HsProtobufAST.DotProtoField (HsProtobuf.FieldNumber 1) - (HsProtobufAST.Prim HsProtobufAST.Int32) - (HsProtobufAST.Single "result") - [] - "")] - + encodeMessage _ OneInt{oneIntResult = oneIntResult} = + ( Hs.mconcat + [ ( HsProtobuf.encodeMessageField + (HsProtobuf.FieldNumber 1) + oneIntResult + ) + ] + ) + decodeMessage _ = + (Hs.pure OneInt) + <*> ( HsProtobuf.at + HsProtobuf.decodeMessageField + (HsProtobuf.FieldNumber 1) + ) + dotProto _ = + [ ( HsProtobufAST.DotProtoField + (HsProtobuf.FieldNumber 1) + (HsProtobufAST.Prim HsProtobufAST.Int32) + (HsProtobufAST.Single "result") + [] + "" + ) + ] + instance HsJSONPB.ToJSONPB OneInt where - toJSONPB (OneInt f1) = (HsJSONPB.object ["result" .= f1]) - toEncodingPB (OneInt f1) = (HsJSONPB.pairs ["result" .= f1]) - + toJSONPB (OneInt f1) = (HsJSONPB.object ["result" .= f1]) + toEncodingPB (OneInt f1) = (HsJSONPB.pairs ["result" .= f1]) + instance HsJSONPB.FromJSONPB OneInt where - parseJSONPB - = (HsJSONPB.withObject "OneInt" - (\ obj -> (Hs.pure OneInt) <*> obj .: "result")) - + parseJSONPB = + ( HsJSONPB.withObject + "OneInt" + (\obj -> (Hs.pure OneInt) <*> obj .: "result") + ) + instance HsJSONPB.ToJSON OneInt where - toJSON = HsJSONPB.toAesonValue - toEncoding = HsJSONPB.toAesonEncoding - + toJSON = HsJSONPB.toAesonValue + toEncoding = HsJSONPB.toAesonEncoding + instance HsJSONPB.FromJSON OneInt where - parseJSON = HsJSONPB.parseJSONPB - -instance HsJSONPB.ToSchema OneInt where - declareNamedSchema _ - = do let declare_result = HsJSONPB.declareSchemaRef - oneIntResult <- declare_result Proxy.Proxy - let _ = Hs.pure OneInt <*> HsJSONPB.asProxy declare_result - Hs.return - (HsJSONPB.NamedSchema{HsJSONPB._namedSchemaName = Hs.Just "OneInt", - HsJSONPB._namedSchemaSchema = - Hs.mempty{HsJSONPB._schemaParamSchema = - Hs.mempty{HsJSONPB._paramSchemaType = - Hs.Just HsJSONPB.SwaggerObject}, - HsJSONPB._schemaProperties = - HsJSONPB.insOrdFromList - [("result", oneIntResult)]}}) + parseJSON = HsJSONPB.parseJSONPB +instance HsJSONPB.ToSchema OneInt where + declareNamedSchema _ = + do + let declare_result = HsJSONPB.declareSchemaRef + oneIntResult <- declare_result Proxy.Proxy + let _ = Hs.pure OneInt <*> HsJSONPB.asProxy declare_result + Hs.return + ( HsJSONPB.NamedSchema + { HsJSONPB._namedSchemaName = Hs.Just "OneInt" + , HsJSONPB._namedSchemaSchema = + Hs.mempty + { HsJSONPB._schemaParamSchema = + Hs.mempty + { HsJSONPB._paramSchemaType = + Hs.Just HsJSONPB.SwaggerObject + } + , HsJSONPB._schemaProperties = + HsJSONPB.insOrdFromList + [("result", oneIntResult)] + } + } + ) diff --git a/examples/tutorial/ArithmeticClient.hs b/examples/tutorial/ArithmeticClient.hs index ce99a028..bee8fb54 100644 --- a/examples/tutorial/ArithmeticClient.hs +++ b/examples/tutorial/ArithmeticClient.hs @@ -1,40 +1,46 @@ -{-# LANGUAGE GADTs #-} -{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE RecordWildCards #-} -import Arithmetic -import Network.GRPC.HighLevel.Generated +import Arithmetic +import Network.GRPC.HighLevel.Generated clientConfig :: ClientConfig -clientConfig = ClientConfig { clientServerEndpoint = "localhost:50051" - , clientArgs = [] - , clientSSLConfig = Nothing - , clientAuthority = Nothing - } +clientConfig = + ClientConfig + { clientServerEndpoint = "localhost:50051" + , clientArgs = [] + , clientSSLConfig = Nothing + , clientAuthority = Nothing + } main :: IO () main = withGRPCClient clientConfig $ \client -> do Arithmetic{..} <- arithmeticClient client -- Request for the Add RPC - ClientNormalResponse (OneInt x) _meta1 _meta2 _status _details - <- arithmeticAdd (ClientNormalRequest (TwoInts 2 2) 1 []) + ClientNormalResponse (OneInt x) _meta1 _meta2 _status _details <- + arithmeticAdd (ClientNormalRequest (TwoInts 2 2) 1 []) putStrLn ("2 + 2 = " ++ show x) -- Request for the RunningSum RPC - ClientWriterResponse reply _streamMeta1 _streamMeta2 streamStatus streamDtls - <- arithmeticRunningSum $ ClientWriterRequest 1 [] $ \send -> do - eithers <- mapM send [OneInt 1, OneInt 2, OneInt 3] - :: IO [Either GRPCIOError ()] - case sequence eithers of - Left err -> error ("Error while streaming: " ++ show err) - Right _ -> return () + ClientWriterResponse reply _streamMeta1 _streamMeta2 streamStatus streamDtls <- + arithmeticRunningSum $ ClientWriterRequest 1 [] $ \send -> do + eithers <- + mapM send [OneInt 1, OneInt 2, OneInt 3] :: + IO [Either GRPCIOError ()] + case sequence eithers of + Left err -> error ("Error while streaming: " ++ show err) + Right _ -> return () case reply of Just (OneInt y) -> print ("1 + 2 + 3 = " ++ show y) - Nothing -> putStrLn ("Client stream failed with status " - ++ show streamStatus - ++ " and details " - ++ show streamDtls) + Nothing -> + putStrLn + ( "Client stream failed with status " + ++ show streamStatus + ++ " and details " + ++ show streamDtls + ) return () diff --git a/examples/tutorial/ArithmeticServer.hs b/examples/tutorial/ArithmeticServer.hs index 47ba2e0c..07d5d892 100644 --- a/examples/tutorial/ArithmeticServer.hs +++ b/examples/tutorial/ArithmeticServer.hs @@ -1,8 +1,8 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE OverloadedLists #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE BangPatterns #-} import Arithmetic import Network.GRPC.HighLevel.Generated @@ -10,38 +10,52 @@ import Network.GRPC.HighLevel.Generated import Data.String (fromString) handlers :: Arithmetic ServerRequest ServerResponse -handlers = Arithmetic { arithmeticAdd = addHandler - , arithmeticRunningSum = runningSumHandler - } +handlers = + Arithmetic + { arithmeticAdd = addHandler + , arithmeticRunningSum = runningSumHandler + } -addHandler :: ServerRequest 'Normal TwoInts OneInt - -> IO (ServerResponse 'Normal OneInt) +addHandler :: + ServerRequest 'Normal TwoInts OneInt -> + IO (ServerResponse 'Normal OneInt) addHandler (ServerNormalRequest _metadata (TwoInts x y)) = do let answer = OneInt (x + y) - return (ServerNormalResponse answer - [("metadata_key_one", "metadata_value")] - StatusOk - "addition is easy!") - + return + ( ServerNormalResponse + answer + [("metadata_key_one", "metadata_value")] + StatusOk + "addition is easy!" + ) -runningSumHandler :: ServerRequest 'ClientStreaming OneInt OneInt - -> IO (ServerResponse 'ClientStreaming OneInt) +runningSumHandler :: + ServerRequest 'ClientStreaming OneInt OneInt -> + IO (ServerResponse 'ClientStreaming OneInt) runningSumHandler (ServerReaderRequest _metadata recv) = loop 0 - where loop !i = - do msg <- recv - case msg of - Left err -> return (ServerReaderResponse - Nothing - [] - StatusUnknown - (fromString (show err))) - Right (Just (OneInt x)) -> loop (i + x) - Right Nothing -> return (ServerReaderResponse - (Just (OneInt i)) - [] - StatusOk - "") + where + loop !i = + do + msg <- recv + case msg of + Left err -> + return + ( ServerReaderResponse + Nothing + [] + StatusUnknown + (fromString (show err)) + ) + Right (Just (OneInt x)) -> loop (i + x) + Right Nothing -> + return + ( ServerReaderResponse + (Just (OneInt i)) + [] + StatusOk + "" + ) options :: ServiceOptions options = defaultServiceOptions diff --git a/fourmolu.yaml b/fourmolu.yaml new file mode 100644 index 00000000..e77d2a57 --- /dev/null +++ b/fourmolu.yaml @@ -0,0 +1,51 @@ +# Number of spaces per indentation step +indentation: 2 + +# Max line length for automatic line breaking +column-limit: none + +# Styling of arrows in type signatures (choices: trailing, leading, or leading-args) +function-arrows: trailing + +# How to place commas in multi-line lists, records, etc. (choices: leading or trailing) +comma-style: leading + +# Styling of import/export lists (choices: leading, trailing, or diff-friendly) +import-export-style: diff-friendly + +# Whether to full-indent or half-indent 'where' bindings past the preceding body +indent-wheres: true + +# Whether to leave a space before an opening record brace +record-brace-space: false + +# Number of spaces between top-level declarations +newlines-between-decls: 1 + +# How to print Haddock comments (choices: single-line, multi-line, or multi-line-compact) +haddock-style: single-line + +# How to print module docstring +haddock-style-module: null + +# Styling of let blocks (choices: auto, inline, newline, or mixed) +let-style: auto + +# How to align the 'in' keyword with respect to the 'let' keyword (choices: left-align, right-align, or no-space) +in-style: right-align + +# Whether to put parentheses around a single constraint (choices: auto, always, or never) +single-constraint-parens: always + +# Output Unicode syntax (choices: detect, always, or never) +unicode: never + +# Give the programmer more choice on where to insert blank lines +respectful: true + +# Fixity information for operators +fixities: [] + +# Module reexports Fourmolu should know about +reexports: + diff --git a/src/Network/GRPC/HighLevel.hs b/src/Network/GRPC/HighLevel.hs index e03375c2..d26ee3fb 100644 --- a/src/Network/GRPC/HighLevel.hs +++ b/src/Network/GRPC/HighLevel.hs @@ -1,54 +1,53 @@ module Network.GRPC.HighLevel ( - --- * Types - MetadataMap(..) -, MethodName(..) -, StatusDetails(..) -, StatusCode(..) -, GRPCIOError(..) -, GRPCImpl(..) -, MkHandler -, ServiceOptions(..) - --- * Server -, Handler(..) -, ServerOptions(..) -, defaultOptions -, serverLoop -, ServerCall(..) -, serverCallCancel -, serverCallIsExpired - --- * Client -, NormalRequestResult(..) -, ClientCall -, clientCallCancel - --- * Client and Server Auth -, AuthContext -, AuthProperty(..) -, getAuthProperties -, addAuthProperty - --- * Server Auth -, ServerSSLConfig(..) -, ProcessMeta -, AuthProcessorResult(..) -, SslClientCertificateRequestType(..) - --- * Client Auth -, ClientSSLConfig(..) -, ClientSSLKeyCertPair(..) -, ClientMetadataCreate -, ClientMetadataCreateResult(..) -, AuthMetadataContext(..) - --- * Streaming utilities -, StreamSend -, StreamRecv + -- * Types + MetadataMap (..), + MethodName (..), + StatusDetails (..), + StatusCode (..), + GRPCIOError (..), + GRPCImpl (..), + MkHandler, + ServiceOptions (..), + + -- * Server + Handler (..), + ServerOptions (..), + defaultOptions, + serverLoop, + ServerCall (..), + serverCallCancel, + serverCallIsExpired, + + -- * Client + NormalRequestResult (..), + ClientCall, + clientCallCancel, + + -- * Client and Server Auth + AuthContext, + AuthProperty (..), + getAuthProperties, + addAuthProperty, + + -- * Server Auth + ServerSSLConfig (..), + ProcessMeta, + AuthProcessorResult (..), + SslClientCertificateRequestType (..), + + -- * Client Auth + ClientSSLConfig (..), + ClientSSLKeyCertPair (..), + ClientMetadataCreate, + ClientMetadataCreateResult (..), + AuthMetadataContext (..), + + -- * Streaming utilities + StreamSend, + StreamRecv, ) - where +where -import Network.GRPC.HighLevel.Server -import Network.GRPC.HighLevel.Generated -import Network.GRPC.LowLevel +import Network.GRPC.HighLevel.Generated +import Network.GRPC.HighLevel.Server +import Network.GRPC.LowLevel diff --git a/src/Network/GRPC/HighLevel/Client.hs b/src/Network/GRPC/HighLevel/Client.hs index 1067ae16..4e188b8e 100644 --- a/src/Network/GRPC/HighLevel/Client.hs +++ b/src/Network/GRPC/HighLevel/Client.hs @@ -1,63 +1,71 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE ScopedTypeVariables #-} -module Network.GRPC.HighLevel.Client - ( ClientError(..) - , ClientRegisterable(..) - , ClientRequest(..) - , ClientResult(..) - , GRPCMethodType(..) - , MetadataMap(..) - , RegisteredMethod - , ServiceClient - , StatusCode(..) - , StatusDetails(..) - , StreamRecv - , StreamSend - , TimeoutSeconds - , WritesDone - - , LL.Client - , LL.ClientConfig(..) - , LL.ClientSSLConfig(..) - , LL.ClientSSLKeyCertPair(..) - , LL.Host(..) - , LL.Port(..) - - , clientRequest +module Network.GRPC.HighLevel.Client ( + ClientError (..), + ClientRegisterable (..), + ClientRequest (..), + ClientResult (..), + GRPCMethodType (..), + MetadataMap (..), + RegisteredMethod, + ServiceClient, + StatusCode (..), + StatusDetails (..), + StreamRecv, + StreamSend, + TimeoutSeconds, + WritesDone, + LL.Client, + LL.ClientConfig (..), + LL.ClientSSLConfig (..), + LL.ClientSSLKeyCertPair (..), + LL.Host (..), + LL.Port (..), + clientRequest, -- * Client utility functions - , acquireClient - , simplifyServerStreaming - , simplifyUnary - ) - + acquireClient, + simplifyServerStreaming, + simplifyUnary, +) where -import Control.Monad.Managed (Managed, liftIO, - managed) -import qualified Data.ByteString.Lazy as BL -import Network.GRPC.HighLevel.Server (convertRecv, - convertSend) -import Network.GRPC.LowLevel (GRPCIOError (..), - GRPCMethodType (..), - MetadataMap (..), - StatusCode (..), - StatusDetails (..), - StreamRecv, StreamSend) -import qualified Network.GRPC.LowLevel as LL -import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds) -import Network.GRPC.LowLevel.Op (WritesDone) -import Proto3.Suite (Message, fromByteString, - toLazyByteString) -import Proto3.Wire.Decode (ParseError) +import Control.Monad.Managed ( + Managed, + liftIO, + managed, + ) +import qualified Data.ByteString.Lazy as BL +import Network.GRPC.HighLevel.Server ( + convertRecv, + convertSend, + ) +import Network.GRPC.LowLevel ( + GRPCIOError (..), + GRPCMethodType (..), + MetadataMap (..), + StatusCode (..), + StatusDetails (..), + StreamRecv, + StreamSend, + ) +import qualified Network.GRPC.LowLevel as LL +import Network.GRPC.LowLevel.CompletionQueue (TimeoutSeconds) +import Network.GRPC.LowLevel.Op (WritesDone) +import Proto3.Suite ( + Message, + fromByteString, + toLazyByteString, + ) +import Proto3.Wire.Decode (ParseError) newtype RegisteredMethod (mt :: GRPCMethodType) request response = RegisteredMethod (LL.RegisteredMethod mt) - deriving Show + deriving (Show) type ServiceClient service = service ClientRequest ClientResult @@ -79,13 +87,14 @@ data ClientResult (streamType :: GRPCMethodType) response where ClientNormalResponse :: response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'Normal response ClientWriterResponse :: Maybe response -> MetadataMap -> MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ClientStreaming response ClientReaderResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'ServerStreaming response - ClientBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'BiDiStreaming response - ClientErrorResponse :: ClientError -> ClientResult streamType response + ClientBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails -> ClientResult 'BiDiStreaming response + ClientErrorResponse :: ClientError -> ClientResult streamType response class ClientRegisterable (methodType :: GRPCMethodType) where - clientRegisterMethod :: LL.Client - -> LL.MethodName - -> IO (RegisteredMethod methodType request response) + clientRegisterMethod :: + LL.Client -> + LL.MethodName -> + IO (RegisteredMethod methodType request response) instance ClientRegisterable 'Normal where clientRegisterMethod client methodName = @@ -103,11 +112,14 @@ instance ClientRegisterable 'BiDiStreaming where clientRegisterMethod client methodName = RegisteredMethod <$> LL.clientRegisterMethodBiDiStreaming client methodName -clientRequest :: (Message request, Message response) => - LL.Client -> RegisteredMethod streamType request response - -> ClientRequest streamType request response -> IO (ClientResult streamType response) +clientRequest :: + (Message request, Message response) => + LL.Client -> + RegisteredMethod streamType request response -> + ClientRequest streamType request response -> + IO (ClientResult streamType response) clientRequest client (RegisteredMethod method) (ClientNormalRequest req timeout meta) = - mkResponse <$> LL.clientRequest client method timeout (BL.toStrict (toLazyByteString req)) meta + mkResponse <$> LL.clientRequest client method timeout (BL.toStrict (toLazyByteString req)) meta where mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Right rsp) = @@ -116,7 +128,7 @@ clientRequest client (RegisteredMethod method) (ClientNormalRequest req timeout Right parsedRsp -> ClientNormalResponse parsedRsp (LL.initMD rsp) (LL.trailMD rsp) (LL.rspCode rsp) (LL.details rsp) clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta handler) = - mkResponse <$> LL.clientWriter client method timeout meta (handler . convertSend) + mkResponse <$> LL.clientWriter client method timeout meta (handler . convertSend) where mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Right (rsp_, initMD_, trailMD_, rspCode_, details_)) = @@ -125,24 +137,24 @@ clientRequest client (RegisteredMethod method) (ClientWriterRequest timeout meta Right parsedRsp -> ClientWriterResponse parsedRsp initMD_ trailMD_ rspCode_ details_ clientRequest client (RegisteredMethod method) (ClientReaderRequest req timeout meta handler) = - mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\cc m recv -> handler cc m (convertRecv recv)) + mkResponse <$> LL.clientReader client method timeout (BL.toStrict (toLazyByteString req)) meta (\cc m recv -> handler cc m (convertRecv recv)) where mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Right (meta_, rspCode_, details_)) = ClientReaderResponse meta_ rspCode_ details_ clientRequest client (RegisteredMethod method) (ClientBiDiRequest timeout meta handler) = - mkResponse <$> LL.clientRW client method timeout meta (\cc _m recv send writesDone -> handler cc meta (convertRecv recv) (convertSend send) writesDone) + mkResponse <$> LL.clientRW client method timeout meta (\cc _m recv send writesDone -> handler cc meta (convertRecv recv) (convertSend send) writesDone) where mkResponse (Left ioError_) = ClientErrorResponse (ClientIOError ioError_) mkResponse (Right (meta_, rspCode_, details_)) = ClientBiDiResponse meta_ rspCode_ details_ -acquireClient - :: LL.ClientConfig - -- ^ The client configuration (host, port, SSL settings, etc) - -> (LL.Client -> IO (ServiceClient service)) - -- ^ The client implementation (typically generated) - -> Managed (ServiceClient service) +acquireClient :: + -- | The client configuration (host, port, SSL settings, etc) + LL.ClientConfig -> + -- | The client implementation (typically generated) + (LL.Client -> IO (ServiceClient service)) -> + Managed (ServiceClient service) acquireClient cfg impl = do g <- managed LL.withGRPC c <- managed (LL.withClient g cfg) @@ -151,67 +163,64 @@ acquireClient cfg impl = do -- | A utility for simplifying server-streaming gRPC client requests; you can -- use this to avoid 'ClientRequest' and 'ClientResult' pattern-matching -- boilerplate at call sites. -simplifyServerStreaming :: TimeoutSeconds - -- ^ RPC call timeout, in seconds - -> MetadataMap - -- ^ RPC call metadata - -> (ClientError -> IO StatusDetails) - -- ^ Handler for client errors - -> (StatusCode -> StatusDetails -> IO StatusDetails) - -- ^ Handler for non-StatusOk response - -> (ClientRequest 'ServerStreaming request response - -> IO (ClientResult 'ServerStreaming response)) - -- ^ Endpoint implementation (typically generated by grpc-haskell) - -> request - -- ^ Request payload - -> (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ()) - -- ^ Stream handler; note that the 'StreamRecv' - -- action must be called repeatedly in order to - -- consume the stream - -> IO StatusDetails +simplifyServerStreaming :: + -- | RPC call timeout, in seconds + TimeoutSeconds -> + -- | RPC call metadata + MetadataMap -> + -- | Handler for client errors + (ClientError -> IO StatusDetails) -> + -- | Handler for non-StatusOk response + (StatusCode -> StatusDetails -> IO StatusDetails) -> + -- | Endpoint implementation (typically generated by grpc-haskell) + ( ClientRequest 'ServerStreaming request response -> + IO (ClientResult 'ServerStreaming response) + ) -> + -- | Request payload + request -> + -- | Stream handler; note that the 'StreamRecv' + -- action must be called repeatedly in order to + -- consume the stream + (LL.ClientCall -> MetadataMap -> StreamRecv response -> IO ()) -> + IO StatusDetails simplifyServerStreaming timeout meta clientError nonStatusOkError f x handler = do - let request = ClientReaderRequest x timeout meta handler response <- f request case response of - ClientReaderResponse _ StatusOk details - -> pure details - - ClientReaderResponse _ statusCode details - -> nonStatusOkError statusCode details - - ClientErrorResponse err - -> clientError err + ClientReaderResponse _ StatusOk details -> + pure details + ClientReaderResponse _ statusCode details -> + nonStatusOkError statusCode details + ClientErrorResponse err -> + clientError err -- | A utility for simplifying unary gRPC client requests; you can use this to -- avoid 'ClientRequest' and 'ClientResult' pattern-matching boilerplate at -- call sites. -simplifyUnary :: TimeoutSeconds - -- ^ RPC call timeout, in seconds - -> MetadataMap - -- ^ RPC call metadata - -> (ClientError -> IO (response, StatusDetails)) - -- ^ Handler for client errors - -> (response -> StatusCode -> StatusDetails -> IO (response, StatusDetails)) - -- ^ Handler for non-StatusOK responses - -> (ClientRequest 'Normal request response -> IO (ClientResult 'Normal response)) - -- ^ Endpoint implementation (typically generated by grpc-haskell) - -> (request -> IO (response, StatusDetails)) - -- ^ The simplified happy-path (StatusOk) unary call action +simplifyUnary :: + -- | RPC call timeout, in seconds + TimeoutSeconds -> + -- | RPC call metadata + MetadataMap -> + -- | Handler for client errors + (ClientError -> IO (response, StatusDetails)) -> + -- | Handler for non-StatusOK responses + (response -> StatusCode -> StatusDetails -> IO (response, StatusDetails)) -> + -- | Endpoint implementation (typically generated by grpc-haskell) + (ClientRequest 'Normal request response -> IO (ClientResult 'Normal response)) -> + -- | The simplified happy-path (StatusOk) unary call action + (request -> IO (response, StatusDetails)) simplifyUnary timeout meta clientError nonStatusOkError f x = do - let request = ClientNormalRequest x timeout meta response <- f request case response of - ClientNormalResponse y _ _ StatusOk details - -> pure (y, details) - - ClientNormalResponse y _ _ code details - -> nonStatusOkError y code details - - ClientErrorResponse err - -> clientError err + ClientNormalResponse y _ _ StatusOk details -> + pure (y, details) + ClientNormalResponse y _ _ code details -> + nonStatusOkError y code details + ClientErrorResponse err -> + clientError err diff --git a/src/Network/GRPC/HighLevel/Generated.hs b/src/Network/GRPC/HighLevel/Generated.hs index a635f511..4a40c2e3 100644 --- a/src/Network/GRPC/HighLevel/Generated.hs +++ b/src/Network/GRPC/HighLevel/Generated.hs @@ -1,49 +1,48 @@ - -{-# LANGUAGE TypeFamilies #-} {-# LANGUAGE DataKinds #-} -{-# LANGUAGE RankNTypes #-} {-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE TypeFamilies #-} module Network.GRPC.HighLevel.Generated ( -- * Types - MetadataMap(..) -, MethodName(..) -, GRPCMethodType(..) -, GRPCImpl(..) -, MkHandler -, Host(..) -, Port(..) -, StatusDetails(..) -, StatusCode(..) -, GRPCIOError(..) + MetadataMap (..), + MethodName (..), + GRPCMethodType (..), + GRPCImpl (..), + MkHandler, + Host (..), + Port (..), + StatusDetails (..), + StatusCode (..), + GRPCIOError (..), -- * Server -, ServiceOptions(..) -, defaultServiceOptions -, ServerCall(..) -, serverCallCancel -, serverCallIsExpired -, ServerRequest(..) -, ServerResponse(..) + ServiceOptions (..), + defaultServiceOptions, + ServerCall (..), + serverCallCancel, + serverCallIsExpired, + ServerRequest (..), + ServerResponse (..), -- * Server Auth -, ServerSSLConfig(..) + ServerSSLConfig (..), -- * Client -, withGRPCClient -, ClientConfig(..) -, ClientError(..) -, ClientRequest(..) -, ClientResult(..) + withGRPCClient, + ClientConfig (..), + ClientError (..), + ClientRequest (..), + ClientResult (..), ) where -import Network.GRPC.HighLevel.Server -import Network.GRPC.HighLevel.Client -import Network.GRPC.LowLevel -import Network.GRPC.LowLevel.Call -import Numeric.Natural -import System.IO (hPutStrLn, stderr) +import Network.GRPC.HighLevel.Client +import Network.GRPC.HighLevel.Server +import Network.GRPC.LowLevel +import Network.GRPC.LowLevel.Call +import Numeric.Natural +import System.IO (hPutStrLn, stderr) -- | Used at the kind level as a parameter to service definitions -- generated by the grpc compiler, with the effect of having the @@ -55,50 +54,51 @@ data GRPCImpl = ServerImpl | ClientImpl -- 'interpreter' type fully applied to get the same effect. type family MkHandler (impl :: GRPCImpl) (methodType :: GRPCMethodType) i o -type instance MkHandler 'ServerImpl 'Normal i o = ServerHandler i o +type instance MkHandler 'ServerImpl 'Normal i o = ServerHandler i o type instance MkHandler 'ServerImpl 'ClientStreaming i o = ServerReaderHandler i o type instance MkHandler 'ServerImpl 'ServerStreaming i o = ServerWriterHandler i o -type instance MkHandler 'ServerImpl 'BiDiStreaming i o = ServerRWHandler i o +type instance MkHandler 'ServerImpl 'BiDiStreaming i o = ServerRWHandler i o -- | Options for a service that was generated from a .proto file. This is -- essentially 'ServerOptions' with the handler fields removed. data ServiceOptions = ServiceOptions - { serverHost :: Host - -- ^ Name of the host the server is running on. - , serverPort :: Port - -- ^ Port on which to listen for requests. - , useCompression :: Bool - -- ^ Whether to use compression when communicating with the client. - , userAgentPrefix :: String - -- ^ Optional custom prefix to add to the user agent string. - , userAgentSuffix :: String - -- ^ Optional custom suffix to add to the user agent string. - , initialMetadata :: MetadataMap - -- ^ Metadata to send at the beginning of each call. - , sslConfig :: Maybe ServerSSLConfig - -- ^ Security configuration. - , logger :: String -> IO () - -- ^ Logging function to use to log errors in handling calls. + { serverHost :: Host + -- ^ Name of the host the server is running on. + , serverPort :: Port + -- ^ Port on which to listen for requests. + , useCompression :: Bool + -- ^ Whether to use compression when communicating with the client. + , userAgentPrefix :: String + -- ^ Optional custom prefix to add to the user agent string. + , userAgentSuffix :: String + -- ^ Optional custom suffix to add to the user agent string. + , initialMetadata :: MetadataMap + -- ^ Metadata to send at the beginning of each call. + , sslConfig :: Maybe ServerSSLConfig + -- ^ Security configuration. + , logger :: String -> IO () + -- ^ Logging function to use to log errors in handling calls. , serverMaxReceiveMessageLength :: Maybe Natural - -- ^ Maximum length (in bytes) that the service may receive in a single message. + -- ^ Maximum length (in bytes) that the service may receive in a single message. , serverMaxMetadataSize :: Maybe Natural - -- ^ Maximum metadata size (in bytes) that the service may receive in a single request. + -- ^ Maximum metadata size (in bytes) that the service may receive in a single request. } defaultServiceOptions :: ServiceOptions -defaultServiceOptions = ServiceOptions - -- names are fully qualified because we use the same fields in LowLevel. - { Network.GRPC.HighLevel.Generated.serverHost = "localhost" - , Network.GRPC.HighLevel.Generated.serverPort = 50051 - , Network.GRPC.HighLevel.Generated.useCompression = False - , Network.GRPC.HighLevel.Generated.userAgentPrefix = "grpc-haskell/0.0.0" - , Network.GRPC.HighLevel.Generated.userAgentSuffix = "" - , Network.GRPC.HighLevel.Generated.initialMetadata = mempty - , Network.GRPC.HighLevel.Generated.sslConfig = Nothing - , Network.GRPC.HighLevel.Generated.logger = hPutStrLn stderr - , Network.GRPC.HighLevel.Generated.serverMaxReceiveMessageLength = Nothing - , Network.GRPC.HighLevel.Generated.serverMaxMetadataSize = Nothing - } +defaultServiceOptions = + ServiceOptions + { -- names are fully qualified because we use the same fields in LowLevel. + Network.GRPC.HighLevel.Generated.serverHost = "localhost" + , Network.GRPC.HighLevel.Generated.serverPort = 50051 + , Network.GRPC.HighLevel.Generated.useCompression = False + , Network.GRPC.HighLevel.Generated.userAgentPrefix = "grpc-haskell/0.0.0" + , Network.GRPC.HighLevel.Generated.userAgentSuffix = "" + , Network.GRPC.HighLevel.Generated.initialMetadata = mempty + , Network.GRPC.HighLevel.Generated.sslConfig = Nothing + , Network.GRPC.HighLevel.Generated.logger = hPutStrLn stderr + , Network.GRPC.HighLevel.Generated.serverMaxReceiveMessageLength = Nothing + , Network.GRPC.HighLevel.Generated.serverMaxMetadataSize = Nothing + } withGRPCClient :: ClientConfig -> (Client -> IO a) -> IO a withGRPCClient c f = withGRPC $ \grpc -> withClient grpc c $ \client -> f client diff --git a/src/Network/GRPC/HighLevel/Server.hs b/src/Network/GRPC/HighLevel/Server.hs index 480efb6a..335a9ca4 100644 --- a/src/Network/GRPC/HighLevel/Server.hs +++ b/src/Network/GRPC/HighLevel/Server.hs @@ -1,21 +1,21 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE KindSignatures #-} -{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE KindSignatures #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} module Network.GRPC.HighLevel.Server where -import qualified Control.Exception as CE -import Control.Monad -import Data.ByteString (ByteString) -import qualified Data.ByteString.Lazy as BL -import Network.GRPC.LowLevel -import Numeric.Natural -import Proto3.Suite.Class -import System.IO +import qualified Control.Exception as CE +import Control.Monad +import Data.ByteString (ByteString) +import qualified Data.ByteString.Lazy as BL +import Network.GRPC.LowLevel +import Numeric.Natural +import Proto3.Suite.Class +import System.IO type ServerCallMetadata = ServerCall () @@ -28,122 +28,145 @@ data ServerRequest (streamType :: GRPCMethodType) request response where ServerBiDiRequest :: ServerCallMetadata -> StreamRecv request -> StreamSend response -> ServerRequest 'BiDiStreaming request response data ServerResponse (streamType :: GRPCMethodType) response where - ServerNormalResponse :: response -> MetadataMap -> StatusCode -> StatusDetails - -> ServerResponse 'Normal response - ServerReaderResponse :: Maybe response -> MetadataMap -> StatusCode -> StatusDetails - -> ServerResponse 'ClientStreaming response - ServerWriterResponse :: MetadataMap -> StatusCode -> StatusDetails - -> ServerResponse 'ServerStreaming response - ServerBiDiResponse :: MetadataMap -> StatusCode -> StatusDetails - -> ServerResponse 'BiDiStreaming response + ServerNormalResponse :: + response -> + MetadataMap -> + StatusCode -> + StatusDetails -> + ServerResponse 'Normal response + ServerReaderResponse :: + Maybe response -> + MetadataMap -> + StatusCode -> + StatusDetails -> + ServerResponse 'ClientStreaming response + ServerWriterResponse :: + MetadataMap -> + StatusCode -> + StatusDetails -> + ServerResponse 'ServerStreaming response + ServerBiDiResponse :: + MetadataMap -> + StatusCode -> + StatusDetails -> + ServerResponse 'BiDiStreaming response type ServerHandler a b = - ServerCall a - -> IO (b, MetadataMap, StatusCode, StatusDetails) + ServerCall a -> + IO (b, MetadataMap, StatusCode, StatusDetails) convertGeneratedServerHandler :: - (ServerRequest 'Normal request response -> IO (ServerResponse 'Normal response)) - -> ServerHandler request response + (ServerRequest 'Normal request response -> IO (ServerResponse 'Normal response)) -> + ServerHandler request response convertGeneratedServerHandler handler call = - do let call' = call { payload = () } - ServerNormalResponse rsp meta stsCode stsDetails <- - handler (ServerNormalRequest call' (payload call)) - return (rsp, meta, stsCode, stsDetails) - -convertServerHandler :: (Message a, Message b) - => ServerHandler a b - -> ServerHandlerLL + do + let call' = call{payload = ()} + ServerNormalResponse rsp meta stsCode stsDetails <- + handler (ServerNormalRequest call' (payload call)) + return (rsp, meta, stsCode, stsDetails) + +convertServerHandler :: + (Message a, Message b) => + ServerHandler a b -> + ServerHandlerLL convertServerHandler f c = case fromByteString (payload c) of - Left x -> CE.throw (GRPCIODecodeError $ show x) - Right x -> do (y, tm, sc, sd) <- f (fmap (const x) c) - return (toBS y, tm, sc, sd) + Left x -> CE.throw (GRPCIODecodeError $ show x) + Right x -> do + (y, tm, sc, sd) <- f (fmap (const x) c) + return (toBS y, tm, sc, sd) -type ServerReaderHandler a b - = ServerCall (MethodPayload 'ClientStreaming) - -> StreamRecv a - -> IO (Maybe b, MetadataMap, StatusCode, StatusDetails) +type ServerReaderHandler a b = + ServerCall (MethodPayload 'ClientStreaming) -> + StreamRecv a -> + IO (Maybe b, MetadataMap, StatusCode, StatusDetails) convertGeneratedServerReaderHandler :: - (ServerRequest 'ClientStreaming request response -> IO (ServerResponse 'ClientStreaming response)) - -> ServerReaderHandler request response + (ServerRequest 'ClientStreaming request response -> IO (ServerResponse 'ClientStreaming response)) -> + ServerReaderHandler request response convertGeneratedServerReaderHandler handler call recv = - do ServerReaderResponse rsp meta stsCode stsDetails <- - handler (ServerReaderRequest call recv) - return (rsp, meta, stsCode, stsDetails) + do + ServerReaderResponse rsp meta stsCode stsDetails <- + handler (ServerReaderRequest call recv) + return (rsp, meta, stsCode, stsDetails) -convertServerReaderHandler :: (Message a, Message b) - => ServerReaderHandler a b - -> ServerReaderHandlerLL +convertServerReaderHandler :: + (Message a, Message b) => + ServerReaderHandler a b -> + ServerReaderHandlerLL convertServerReaderHandler f c recv = serialize <$> f c (convertRecv recv) where serialize (mmsg, m, sc, sd) = (toBS <$> mmsg, m, sc, sd) type ServerWriterHandler a b = - ServerCall a - -> StreamSend b - -> IO (MetadataMap, StatusCode, StatusDetails) + ServerCall a -> + StreamSend b -> + IO (MetadataMap, StatusCode, StatusDetails) convertGeneratedServerWriterHandler :: - (ServerRequest 'ServerStreaming request response -> IO (ServerResponse 'ServerStreaming response)) - -> ServerWriterHandler request response + (ServerRequest 'ServerStreaming request response -> IO (ServerResponse 'ServerStreaming response)) -> + ServerWriterHandler request response convertGeneratedServerWriterHandler handler call send = - do let call' = call { payload = () } - ServerWriterResponse meta stsCode stsDetails <- - handler (ServerWriterRequest call' (payload call) send) - return (meta, stsCode, stsDetails) - -convertServerWriterHandler :: (Message a, Message b) => - ServerWriterHandler a b - -> ServerWriterHandlerLL + do + let call' = call{payload = ()} + ServerWriterResponse meta stsCode stsDetails <- + handler (ServerWriterRequest call' (payload call) send) + return (meta, stsCode, stsDetails) + +convertServerWriterHandler :: + (Message a, Message b) => + ServerWriterHandler a b -> + ServerWriterHandlerLL convertServerWriterHandler f c send = f (convert <$> c) (convertSend send) where convert bs = case fromByteString bs of - Left x -> CE.throw (GRPCIODecodeError $ show x) + Left x -> CE.throw (GRPCIODecodeError $ show x) Right x -> x -type ServerRWHandler a b - = ServerCall (MethodPayload 'BiDiStreaming) - -> StreamRecv a - -> StreamSend b - -> IO (MetadataMap, StatusCode, StatusDetails) +type ServerRWHandler a b = + ServerCall (MethodPayload 'BiDiStreaming) -> + StreamRecv a -> + StreamSend b -> + IO (MetadataMap, StatusCode, StatusDetails) convertGeneratedServerRWHandler :: - (ServerRequest 'BiDiStreaming request response -> IO (ServerResponse 'BiDiStreaming response)) - -> ServerRWHandler request response + (ServerRequest 'BiDiStreaming request response -> IO (ServerResponse 'BiDiStreaming response)) -> + ServerRWHandler request response convertGeneratedServerRWHandler handler call recv send = - do ServerBiDiResponse meta stsCode stsDetails <- - handler (ServerBiDiRequest call recv send) - return (meta, stsCode, stsDetails) + do + ServerBiDiResponse meta stsCode stsDetails <- + handler (ServerBiDiRequest call recv send) + return (meta, stsCode, stsDetails) -convertServerRWHandler :: (Message a, Message b) - => ServerRWHandler a b - -> ServerRWHandlerLL +convertServerRWHandler :: + (Message a, Message b) => + ServerRWHandler a b -> + ServerRWHandlerLL convertServerRWHandler f c recv send = f c (convertRecv recv) (convertSend send) -convertRecv :: Message a => StreamRecv ByteString -> StreamRecv a +convertRecv :: (Message a) => StreamRecv ByteString -> StreamRecv a convertRecv = fmap $ \e -> do msg <- e case msg of Nothing -> return Nothing Just bs -> case fromByteString bs of - Left x -> Left (GRPCIODecodeError $ show x) - Right x -> return (Just x) + Left x -> Left (GRPCIODecodeError $ show x) + Right x -> return (Just x) -convertSend :: Message a => StreamSend ByteString -> StreamSend a +convertSend :: (Message a) => StreamSend ByteString -> StreamSend a convertSend s = s . toBS -toBS :: Message a => a -> ByteString +toBS :: (Message a) => a -> ByteString toBS = BL.toStrict . toLazyByteString data Handler (a :: GRPCMethodType) where - UnaryHandler :: (Message c, Message d) => MethodName -> ServerHandler c d -> Handler 'Normal + UnaryHandler :: (Message c, Message d) => MethodName -> ServerHandler c d -> Handler 'Normal ClientStreamHandler :: (Message c, Message d) => MethodName -> ServerReaderHandler c d -> Handler 'ClientStreaming ServerStreamHandler :: (Message c, Message d) => MethodName -> ServerWriterHandler c d -> Handler 'ServerStreaming - BiDiStreamHandler :: (Message c, Message d) => MethodName -> ServerRWHandler c d -> Handler 'BiDiStreaming + BiDiStreamHandler :: (Message c, Message d) => MethodName -> ServerRWHandler c d -> Handler 'BiDiStreaming data AnyHandler = forall (a :: GRPCMethodType). AnyHandler (Handler a) @@ -151,19 +174,20 @@ anyHandlerMethodName :: AnyHandler -> MethodName anyHandlerMethodName (AnyHandler m) = handlerMethodName m handlerMethodName :: Handler a -> MethodName -handlerMethodName (UnaryHandler m _) = m +handlerMethodName (UnaryHandler m _) = m handlerMethodName (ClientStreamHandler m _) = m handlerMethodName (ServerStreamHandler m _) = m -handlerMethodName (BiDiStreamHandler m _) = m +handlerMethodName (BiDiStreamHandler m _) = m -- | Handles errors that result from trying to handle a call on the server. -- For each error, takes a different action depending on the severity in the -- context of handling a server call. This also tries to give an indication of -- whether the error is our fault or user error. -handleCallError :: (String -> IO ()) - -- ^ logging function - -> Either GRPCIOError a - -> IO () +handleCallError :: + -- | logging function + (String -> IO ()) -> + Either GRPCIOError a -> + IO () handleCallError _ (Right _) = return () handleCallError _ (Left GRPCIOTimeout) = -- Probably a benign timeout (such as a client disappearing), noop for now. @@ -178,20 +202,22 @@ handleCallError logMsg (Left (GRPCIOHandlerException e)) = handleCallError logMsg (Left x) = logMsg $ show x ++ ": This probably indicates a bug in gRPC-haskell. Please report this error." -loopWError :: Int - -> ServerOptions - -> IO (Either GRPCIOError a) - -> IO () +loopWError :: + Int -> + ServerOptions -> + IO (Either GRPCIOError a) -> + IO () loopWError i o@ServerOptions{..} f = do - when (i `mod` 100 == 0) $ putStrLn $ "i = " ++ show i - f >>= handleCallError optLogger - loopWError (i + 1) o f + when (i `mod` 100 == 0) $ putStrLn $ "i = " ++ show i + f >>= handleCallError optLogger + loopWError (i + 1) o f -- TODO: options for setting initial/trailing metadata -handleLoop :: Server - -> ServerOptions - -> (Handler a, RegisteredMethod a) - -> IO () +handleLoop :: + Server -> + ServerOptions -> + (Handler a, RegisteredMethod a) -> + IO () handleLoop s o (UnaryHandler _ f, rm) = loopWError 0 o $ serverHandleNormalCall s rm mempty $ convertServerHandler f handleLoop s o (ClientStreamHandler _ f, rm) = @@ -202,56 +228,58 @@ handleLoop s o (BiDiStreamHandler _ f, rm) = loopWError 0 o $ serverRW s rm mempty $ convertServerRWHandler f data ServerOptions = ServerOptions - { optNormalHandlers :: [Handler 'Normal] - -- ^ Handlers for unary (non-streaming) calls. + { optNormalHandlers :: [Handler 'Normal] + -- ^ Handlers for unary (non-streaming) calls. , optClientStreamHandlers :: [Handler 'ClientStreaming] - -- ^ Handlers for client streaming calls. + -- ^ Handlers for client streaming calls. , optServerStreamHandlers :: [Handler 'ServerStreaming] - -- ^ Handlers for server streaming calls. - , optBiDiStreamHandlers :: [Handler 'BiDiStreaming] - -- ^ Handlers for bidirectional streaming calls. - , optServerHost :: Host - -- ^ Name of the host the server is running on. - , optServerPort :: Port - -- ^ Port on which to listen for requests. - , optUseCompression :: Bool - -- ^ Whether to use compression when communicating with the client. - , optUserAgentPrefix :: String - -- ^ Optional custom prefix to add to the user agent string. - , optUserAgentSuffix :: String - -- ^ Optional custom suffix to add to the user agent string. - , optInitialMetadata :: MetadataMap - -- ^ Metadata to send at the beginning of each call. - , optSSLConfig :: Maybe ServerSSLConfig - -- ^ Security configuration. - , optLogger :: String -> IO () - -- ^ Logging function to use to log errors in handling calls. + -- ^ Handlers for server streaming calls. + , optBiDiStreamHandlers :: [Handler 'BiDiStreaming] + -- ^ Handlers for bidirectional streaming calls. + , optServerHost :: Host + -- ^ Name of the host the server is running on. + , optServerPort :: Port + -- ^ Port on which to listen for requests. + , optUseCompression :: Bool + -- ^ Whether to use compression when communicating with the client. + , optUserAgentPrefix :: String + -- ^ Optional custom prefix to add to the user agent string. + , optUserAgentSuffix :: String + -- ^ Optional custom suffix to add to the user agent string. + , optInitialMetadata :: MetadataMap + -- ^ Metadata to send at the beginning of each call. + , optSSLConfig :: Maybe ServerSSLConfig + -- ^ Security configuration. + , optLogger :: String -> IO () + -- ^ Logging function to use to log errors in handling calls. , optMaxReceiveMessageLength :: Maybe Natural - -- ^ Maximum length (in bytes) that the service may receive in a single message. + -- ^ Maximum length (in bytes) that the service may receive in a single message. , optMaxMetadataSize :: Maybe Natural - -- ^ Maximum metadata size (in bytes) that the service may receive in a single request. + -- ^ Maximum metadata size (in bytes) that the service may receive in a single request. } defaultOptions :: ServerOptions -defaultOptions = ServerOptions - { optNormalHandlers = [] - , optClientStreamHandlers = [] - , optServerStreamHandlers = [] - , optBiDiStreamHandlers = [] - , optServerHost = "localhost" - , optServerPort = 50051 - , optUseCompression = False - , optUserAgentPrefix = "grpc-haskell/0.0.0" - , optUserAgentSuffix = "" - , optInitialMetadata = mempty - , optSSLConfig = Nothing - , optLogger = hPutStrLn stderr - , optMaxReceiveMessageLength = Nothing - , optMaxMetadataSize = Nothing - } +defaultOptions = + ServerOptions + { optNormalHandlers = [] + , optClientStreamHandlers = [] + , optServerStreamHandlers = [] + , optBiDiStreamHandlers = [] + , optServerHost = "localhost" + , optServerPort = 50051 + , optUseCompression = False + , optUserAgentPrefix = "grpc-haskell/0.0.0" + , optUserAgentSuffix = "" + , optInitialMetadata = mempty + , optSSLConfig = Nothing + , optLogger = hPutStrLn stderr + , optMaxReceiveMessageLength = Nothing + , optMaxMetadataSize = Nothing + } serverLoop :: ServerOptions -> IO () serverLoop _opts = fail "Registered method-based serverLoop NYI" + {- withGRPC $ \grpc -> withServer grpc (mkConfig opts) $ \server -> do diff --git a/src/Network/GRPC/HighLevel/Server/Unregistered.hs b/src/Network/GRPC/HighLevel/Server/Unregistered.hs index b3f2d29d..a3594bfb 100644 --- a/src/Network/GRPC/HighLevel/Server/Unregistered.hs +++ b/src/Network/GRPC/HighLevel/Server/Unregistered.hs @@ -1,47 +1,53 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE GADTs #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedStrings #-} -{-# LANGUAGE RankNTypes #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RankNTypes #-} +{-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} module Network.GRPC.HighLevel.Server.Unregistered where -import Control.Arrow -import Control.Concurrent.MVar (newEmptyMVar, - putMVar, - takeMVar) -import qualified Control.Exception as CE -import Control.Monad -import Data.Foldable (find) -import Network.GRPC.HighLevel.Server -import Network.GRPC.LowLevel -import Network.GRPC.LowLevel.Server (forkServer) -import qualified Network.GRPC.LowLevel.Call.Unregistered as U +import Control.Arrow +import Control.Concurrent.MVar ( + newEmptyMVar, + putMVar, + takeMVar, + ) +import qualified Control.Exception as CE +import Control.Monad +import Data.Foldable (find) +import Network.GRPC.HighLevel.Server +import Network.GRPC.LowLevel +import qualified Network.GRPC.LowLevel.Call.Unregistered as U +import Network.GRPC.LowLevel.Server (forkServer) import qualified Network.GRPC.LowLevel.Server.Unregistered as U -import Proto3.Suite.Class +import Proto3.Suite.Class -dispatchLoop :: Server - -> (String -> IO ()) - -> MetadataMap - -> [Handler 'Normal] - -> [Handler 'ClientStreaming] - -> [Handler 'ServerStreaming] - -> [Handler 'BiDiStreaming] - -> IO () +dispatchLoop :: + Server -> + (String -> IO ()) -> + MetadataMap -> + [Handler 'Normal] -> + [Handler 'ClientStreaming] -> + [Handler 'ServerStreaming] -> + [Handler 'BiDiStreaming] -> + IO () dispatchLoop s logger md hN hC hS hB = forever $ U.withServerCallAsync s $ \sc -> case findHandler sc allHandlers of Just (AnyHandler ah) -> case ah of - UnaryHandler _ h -> unaryHandler sc h + UnaryHandler _ h -> unaryHandler sc h ClientStreamHandler _ h -> csHandler sc h ServerStreamHandler _ h -> ssHandler sc h - BiDiStreamHandler _ h -> bdHandler sc h - Nothing -> unknownHandler sc + BiDiStreamHandler _ h -> bdHandler sc h + Nothing -> unknownHandler sc where - allHandlers = map AnyHandler hN ++ map AnyHandler hC - ++ map AnyHandler hS ++ map AnyHandler hB + allHandlers = + map AnyHandler hN + ++ map AnyHandler hC + ++ map AnyHandler hS + ++ map AnyHandler hB findHandler sc = find ((== U.callMethod sc) . anyHandlerMethodName) @@ -66,7 +72,8 @@ dispatchLoop s logger md hN hC hS hB = handleError :: IO a -> IO () handleError = (handleCallError logger . left herr =<<) . CE.try - where herr (e :: CE.SomeException) = GRPCIOHandlerException (show e) + where + herr (e :: CE.SomeException) = GRPCIOHandlerException (show e) serverLoop :: ServerOptions -> IO () serverLoop ServerOptions{..} = @@ -97,34 +104,34 @@ serverLoop ServerOptions{..} = -- kills the "dispatchLoop" thread and any other thread we -- may have started with "forkServer". done <- newEmptyMVar - launched <- forkServer server $ - dispatchLoop server - optLogger - optInitialMetadata - optNormalHandlers - optClientStreamHandlers - optServerStreamHandlers - optBiDiStreamHandlers - `CE.finally` putMVar done () + launched <- + forkServer server $ + dispatchLoop + server + optLogger + optInitialMetadata + optNormalHandlers + optClientStreamHandlers + optServerStreamHandlers + optBiDiStreamHandlers + `CE.finally` putMVar done () when launched $ takeMVar done where - config = ServerConfig - { host = optServerHost - , port = optServerPort - , methodsToRegisterNormal = [] - , methodsToRegisterClientStreaming = [] - , methodsToRegisterServerStreaming = [] - , methodsToRegisterBiDiStreaming = [] - , serverArgs = - [CompressionAlgArg GrpcCompressDeflate | optUseCompression] - ++ - [ UserAgentPrefix optUserAgentPrefix - , UserAgentSuffix optUserAgentSuffix - ] - ++ - foldMap (pure . MaxReceiveMessageLength) optMaxReceiveMessageLength - ++ - foldMap (pure . MaxMetadataSize) optMaxMetadataSize - , sslConfig = optSSLConfig - } + config = + ServerConfig + { host = optServerHost + , port = optServerPort + , methodsToRegisterNormal = [] + , methodsToRegisterClientStreaming = [] + , methodsToRegisterServerStreaming = [] + , methodsToRegisterBiDiStreaming = [] + , serverArgs = + [CompressionAlgArg GrpcCompressDeflate | optUseCompression] + ++ [ UserAgentPrefix optUserAgentPrefix + , UserAgentSuffix optUserAgentSuffix + ] + ++ foldMap (pure . MaxReceiveMessageLength) optMaxReceiveMessageLength + ++ foldMap (pure . MaxMetadataSize) optMaxMetadataSize + , sslConfig = optSSLConfig + } diff --git a/tests/GeneratedTests.hs b/tests/GeneratedTests.hs index edbe6b4f..298b74f4 100644 --- a/tests/GeneratedTests.hs +++ b/tests/GeneratedTests.hs @@ -1,4 +1,3 @@ - {-# LANGUAGE OverloadedStrings #-} module GeneratedTests where @@ -12,35 +11,42 @@ import Proto3.Suite.DotProto.Generate import Turtle hiding (err) generatedTests :: TestTree -generatedTests = testGroup "Code generator tests" - [ testServerGeneration - , testClientGeneration ] +generatedTests = + testGroup + "Code generator tests" + [ testServerGeneration + , testClientGeneration + ] testServerGeneration :: TestTree testServerGeneration = testCase "server generation" $ do mktree hsTmpDir mktree pyTmpDir - let args = CompileArgs - { includeDir = ["tests"] - , extraInstanceFiles = [] - , inputProto = "simple.proto" - , outputDir = hsTmpDir - , stringType = StringType "Data.Text.Lazy" "Text" - , recordStyle = LargeRecords - } + let args = + CompileArgs + { includeDir = ["tests"] + , extraInstanceFiles = [] + , inputProto = "simple.proto" + , outputDir = hsTmpDir + , stringType = StringType "Data.Text.Lazy" "Text" + , recordStyle = LargeRecords + } compileDotProtoFileOrDie args - do exitCode <- proc "tests/simple-server.sh" [hsTmpDir] empty - exitCode @?= ExitSuccess + do + exitCode <- proc "tests/simple-server.sh" [hsTmpDir] empty + exitCode @?= ExitSuccess - do exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty - exitCode @?= ExitSuccess + do + exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty + exitCode @?= ExitSuccess runManaged $ do serverExitCodeA <- fork (shell (hsTmpDir <> "/simple-server") empty) - clientExitCodeA <- fork - (export "PYTHONPATH" pyTmpDir >> shell "tests/test-client.sh" empty) + clientExitCodeA <- + fork + (export "PYTHONPATH" pyTmpDir >> shell "tests/test-client.sh" empty) liftIO $ do serverExitCode <- liftIO (wait serverExitCodeA) @@ -57,25 +63,29 @@ testClientGeneration = testCase "client generation" $ do mktree hsTmpDir mktree pyTmpDir - let args = CompileArgs - { includeDir = ["tests"] - , extraInstanceFiles = [] - , inputProto = "simple.proto" - , outputDir = hsTmpDir - , stringType = StringType "Data.Text.Lazy" "Text" - , recordStyle = LargeRecords - } + let args = + CompileArgs + { includeDir = ["tests"] + , extraInstanceFiles = [] + , inputProto = "simple.proto" + , outputDir = hsTmpDir + , stringType = StringType "Data.Text.Lazy" "Text" + , recordStyle = LargeRecords + } compileDotProtoFileOrDie args - do exitCode <- proc "tests/simple-client.sh" [hsTmpDir] empty - exitCode @?= ExitSuccess + do + exitCode <- proc "tests/simple-client.sh" [hsTmpDir] empty + exitCode @?= ExitSuccess - do exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty - exitCode @?= ExitSuccess + do + exitCode <- proc "tests/protoc.sh" [pyTmpDir] empty + exitCode @?= ExitSuccess runManaged $ do - serverExitCodeA <- fork - (export "PYTHONPATH" pyTmpDir >> shell "tests/test-server.sh" empty) + serverExitCodeA <- + fork + (export "PYTHONPATH" pyTmpDir >> shell "tests/test-server.sh" empty) clientExitCodeA <- fork (shell (hsTmpDir <> "/simple-client") empty) liftIO $ do @@ -88,6 +98,6 @@ testClientGeneration = testCase "client generation" $ do rmtree hsTmpDir rmtree pyTmpDir -hsTmpDir, pyTmpDir :: IsString a => a +hsTmpDir, pyTmpDir :: (IsString a) => a hsTmpDir = "tests/tmp" pyTmpDir = "tests/py-tmp" diff --git a/tests/Properties.hs b/tests/Properties.hs index 6f727935..ab971200 100644 --- a/tests/Properties.hs +++ b/tests/Properties.hs @@ -1,5 +1,5 @@ -import Test.Tasty -import GeneratedTests +import GeneratedTests +import Test.Tasty main :: IO () -main = defaultMain $ testGroup "GRPC Unit Tests" [ generatedTests ] +main = defaultMain $ testGroup "GRPC Unit Tests" [generatedTests] diff --git a/tests/TestClient.hs b/tests/TestClient.hs index ffbb6834..fffbdaea 100644 --- a/tests/TestClient.hs +++ b/tests/TestClient.hs @@ -1,6 +1,6 @@ -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE ScopedTypeVariables #-} @@ -10,126 +10,150 @@ import Prelude hiding (sum) import Simple +import Control.Arrow import Control.Concurrent import Control.Concurrent.MVar +import Control.Exception import Control.Monad import Control.Monad.IO.Class -import Control.Exception -import Control.Arrow -import Data.Monoid import Data.Foldable (sum) +import Data.Monoid import Data.String -import Data.Word import Data.Vector (fromList) +import Data.Word -import Network.GRPC.LowLevel import Network.GRPC.HighLevel.Client +import Network.GRPC.LowLevel import Proto3.Suite import System.Random import Test.Tasty -import Test.Tasty.HUnit ((@?=), assertFailure, testCase) +import Test.Tasty.HUnit (assertFailure, testCase, (@?=)) testNormalCall SimpleService{..} = testCase "Normal call" $ - do randoms <- fromList <$> replicateM 1000 (randomRIO (1, 1000)) - let req = SimpleServiceRequest "NormalRequest" randoms - res <- simpleServicenormalCall - (ClientNormalRequest req 10 mempty) - case res of - ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) - ClientNormalResponse SimpleServiceResponse{..} _ _ stsCode _ -> - do stsCode @?= StatusOk - simpleServiceResponseResponse @?= "NormalRequest" - simpleServiceResponseNum @?= sum randoms + do + randoms <- fromList <$> replicateM 1000 (randomRIO (1, 1000)) + let req = SimpleServiceRequest "NormalRequest" randoms + res <- + simpleServicenormalCall + (ClientNormalRequest req 10 mempty) + case res of + ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) + ClientNormalResponse SimpleServiceResponse{..} _ _ stsCode _ -> + do + stsCode @?= StatusOk + simpleServiceResponseResponse @?= "NormalRequest" + simpleServiceResponseNum @?= sum randoms testClientStreamingCall SimpleService{..} = testCase "Client-streaming call" $ - do iterationCount <- randomRIO (5, 50) - v <- newEmptyMVar - res <- simpleServiceclientStreamingCall . ClientWriterRequest 10 mempty $ \send -> - do (finalName, totalSum) <- - fmap ((mconcat *** (sum . mconcat)) . unzip) . - replicateM iterationCount $ - do randoms <- fromList <$> replicateM 1000 (randomRIO (1, 1000)) - name <- fromString <$> replicateM 10 (randomRIO ('a', 'z')) - send (SimpleServiceRequest name randoms) - pure (name, randoms) - putMVar v (finalName, totalSum) - - (finalName, totalSum) <- readMVar v - case res of - ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) - ClientWriterResponse Nothing _ _ _ _ -> assertFailure "No response received" - ClientWriterResponse (Just SimpleServiceResponse{..}) _ _ stsCode _ -> - do stsCode @?= StatusOk - simpleServiceResponseResponse @?= finalName - simpleServiceResponseNum @?= totalSum + do + iterationCount <- randomRIO (5, 50) + v <- newEmptyMVar + res <- simpleServiceclientStreamingCall . ClientWriterRequest 10 mempty $ \send -> + do + (finalName, totalSum) <- + fmap ((mconcat *** (sum . mconcat)) . unzip) + . replicateM iterationCount + $ do + randoms <- fromList <$> replicateM 1000 (randomRIO (1, 1000)) + name <- fromString <$> replicateM 10 (randomRIO ('a', 'z')) + send (SimpleServiceRequest name randoms) + pure (name, randoms) + putMVar v (finalName, totalSum) + + (finalName, totalSum) <- readMVar v + case res of + ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) + ClientWriterResponse Nothing _ _ _ _ -> assertFailure "No response received" + ClientWriterResponse (Just SimpleServiceResponse{..}) _ _ stsCode _ -> + do + stsCode @?= StatusOk + simpleServiceResponseResponse @?= finalName + simpleServiceResponseNum @?= totalSum testServerStreamingCall SimpleService{..} = testCase "Server-streaming call" $ - do numCount <- randomRIO (50, 500) - nums <- replicateM numCount randomIO - - let checkResults [] recv = - do res <- recv - case res of - Left err -> assertFailure ("recv error: " <> show err) - Right Nothing -> pure () - Right (Just _) -> assertFailure "recv: elements past end of stream" - checkResults (expNum:nums) recv = - do res <- recv - case res of - Left err -> assertFailure ("recv error: " <> show err) - Right Nothing -> assertFailure ("recv: stream ended earlier than expected") - Right (Just (SimpleServiceResponse response num)) -> - do response @?= "Test" - num @?= expNum - checkResults nums recv - res <- simpleServiceserverStreamingCall $ - ClientReaderRequest (SimpleServiceRequest "Test" (fromList nums)) 10 mempty - (\_ _ -> checkResults nums) - case res of - ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) - ClientReaderResponse _ sts _ -> - sts @?= StatusOk + do + numCount <- randomRIO (50, 500) + nums <- replicateM numCount randomIO + + let checkResults [] recv = + do + res <- recv + case res of + Left err -> assertFailure ("recv error: " <> show err) + Right Nothing -> pure () + Right (Just _) -> assertFailure "recv: elements past end of stream" + checkResults (expNum : nums) recv = + do + res <- recv + case res of + Left err -> assertFailure ("recv error: " <> show err) + Right Nothing -> assertFailure ("recv: stream ended earlier than expected") + Right (Just (SimpleServiceResponse response num)) -> + do + response @?= "Test" + num @?= expNum + checkResults nums recv + res <- + simpleServiceserverStreamingCall $ + ClientReaderRequest + (SimpleServiceRequest "Test" (fromList nums)) + 10 + mempty + (\_ _ -> checkResults nums) + case res of + ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) + ClientReaderResponse _ sts _ -> + sts @?= StatusOk testBiDiStreamingCall SimpleService{..} = testCase "Bidi-streaming call" $ - do let handleRequests (0 :: Int) _ _ done = done >> pure () - handleRequests n recv send done = - do numCount <- randomRIO (10, 1000) - nums <- fromList <$> replicateM numCount (randomRIO (1, 1000)) - testName <- fromString <$> replicateM 10 (randomRIO ('a', 'z')) - send (SimpleServiceRequest testName nums) - - res <- recv - case res of - Left err -> assertFailure ("recv error: " <> show err) - Right Nothing -> pure () - Right (Just (SimpleServiceResponse name total)) -> - do name @?= testName - total @?= sum nums - handleRequests (n - 1) recv send done - - iterations <- randomRIO (50, 500) - - res <- simpleServicebiDiStreamingCall $ - ClientBiDiRequest 10 mempty (\_ _ -> handleRequests iterations) - case res of - ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) - ClientBiDiResponse _ sts _ -> - sts @?= StatusOk + do + let handleRequests (0 :: Int) _ _ done = done >> pure () + handleRequests n recv send done = + do + numCount <- randomRIO (10, 1000) + nums <- fromList <$> replicateM numCount (randomRIO (1, 1000)) + testName <- fromString <$> replicateM 10 (randomRIO ('a', 'z')) + send (SimpleServiceRequest testName nums) + + res <- recv + case res of + Left err -> assertFailure ("recv error: " <> show err) + Right Nothing -> pure () + Right (Just (SimpleServiceResponse name total)) -> + do + name @?= testName + total @?= sum nums + handleRequests (n - 1) recv send done + + iterations <- randomRIO (50, 500) + + res <- + simpleServicebiDiStreamingCall $ + ClientBiDiRequest 10 mempty (\_ _ -> handleRequests iterations) + case res of + ClientErrorResponse err -> assertFailure ("ClientErrorResponse: " <> show err) + ClientBiDiResponse _ sts _ -> + sts @?= StatusOk main :: IO () main = do threadDelay 10000000 withGRPC $ \grpc -> - withClient grpc (ClientConfig "localhost:50051" [] Nothing Nothing) $ \client -> - do service@SimpleService{..} <- simpleServiceClient client - - (defaultMain $ testGroup "Send gRPC requests" - [ testNormalCall service - , testClientStreamingCall service - , testServerStreamingCall service - , testBiDiStreamingCall service ]) `finally` - (simpleServicedone (ClientNormalRequest SimpleServiceDone 10 mempty)) + withClient grpc (ClientConfig "localhost:50051" [] Nothing Nothing) $ \client -> + do + service@SimpleService{..} <- simpleServiceClient client + + ( defaultMain $ + testGroup + "Send gRPC requests" + [ testNormalCall service + , testClientStreamingCall service + , testServerStreamingCall service + , testBiDiStreamingCall service + ] + ) + `finally` (simpleServicedone (ClientNormalRequest SimpleServiceDone 10 mempty)) diff --git a/tests/TestServer.hs b/tests/TestServer.hs index de49a414..296fcd60 100644 --- a/tests/TestServer.hs +++ b/tests/TestServer.hs @@ -1,6 +1,6 @@ -{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE DataKinds #-} {-# LANGUAGE GADTs #-} +{-# LANGUAGE OverloadedStrings #-} module Main where @@ -13,62 +13,75 @@ import Control.Concurrent.MVar import Control.Monad import Control.Monad.IO.Class -import Data.Monoid import Data.Foldable (sum) +import Data.Monoid import Data.String -import Network.GRPC.LowLevel -import Network.GRPC.HighLevel.Server import Network.GRPC.HighLevel.Generated (defaultServiceOptions) +import Network.GRPC.HighLevel.Server +import Network.GRPC.LowLevel handleNormalCall :: ServerRequest 'Normal SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'Normal SimpleServiceResponse) handleNormalCall (ServerNormalRequest meta (SimpleServiceRequest request nums)) = pure (ServerNormalResponse (SimpleServiceResponse request result) mempty StatusOk (StatusDetails "")) - where result = sum nums + where + result = sum nums handleClientStreamingCall :: ServerRequest 'ClientStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ClientStreaming SimpleServiceResponse) handleClientStreamingCall (ServerReaderRequest call recvRequest) = go 0 "" - where go sumAccum nameAccum = - recvRequest >>= \req -> - case req of - Left ioError -> pure (ServerReaderResponse Nothing mempty StatusCancelled (StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))) - Right Nothing -> - pure (ServerReaderResponse (Just (SimpleServiceResponse nameAccum sumAccum)) mempty StatusOk (StatusDetails "")) - Right (Just (SimpleServiceRequest name nums)) -> - go (sumAccum + sum nums) (nameAccum <> name) + where + go sumAccum nameAccum = + recvRequest >>= \req -> + case req of + Left ioError -> pure (ServerReaderResponse Nothing mempty StatusCancelled (StatusDetails ("handleClientStreamingCall: IO error: " <> fromString (show ioError)))) + Right Nothing -> + pure (ServerReaderResponse (Just (SimpleServiceResponse nameAccum sumAccum)) mempty StatusOk (StatusDetails "")) + Right (Just (SimpleServiceRequest name nums)) -> + go (sumAccum + sum nums) (nameAccum <> name) handleServerStreamingCall :: ServerRequest 'ServerStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'ServerStreaming SimpleServiceResponse) handleServerStreamingCall (ServerWriterRequest call (SimpleServiceRequest requestName nums) sendResponse) = go - where go = do forM_ nums $ \num -> - sendResponse (SimpleServiceResponse requestName num) - pure (ServerWriterResponse mempty StatusOk (StatusDetails "")) + where + go = do + forM_ nums $ \num -> + sendResponse (SimpleServiceResponse requestName num) + pure (ServerWriterResponse mempty StatusOk (StatusDetails "")) handleBiDiStreamingCall :: ServerRequest 'BiDiStreaming SimpleServiceRequest SimpleServiceResponse -> IO (ServerResponse 'BiDiStreaming SimpleServiceResponse) handleBiDiStreamingCall (ServerBiDiRequest call recvRequest sendResponse) = go - where go = recvRequest >>= \req -> - case req of - Left ioError -> - pure (ServerBiDiResponse mempty StatusCancelled (StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))) - Right Nothing -> - pure (ServerBiDiResponse mempty StatusOk (StatusDetails "")) - Right (Just (SimpleServiceRequest name nums)) -> - do sendResponse (SimpleServiceResponse name (sum nums)) - go + where + go = + recvRequest >>= \req -> + case req of + Left ioError -> + pure (ServerBiDiResponse mempty StatusCancelled (StatusDetails ("handleBiDiStreamingCall: IO error: " <> fromString (show ioError)))) + Right Nothing -> + pure (ServerBiDiResponse mempty StatusOk (StatusDetails "")) + Right (Just (SimpleServiceRequest name nums)) -> + do + sendResponse (SimpleServiceResponse name (sum nums)) + go handleDone :: MVar () -> ServerRequest 'Normal SimpleServiceDone SimpleServiceDone -> IO (ServerResponse 'Normal SimpleServiceDone) handleDone exitVar (ServerNormalRequest _ req) = - do forkIO (threadDelay 5000 >> putMVar exitVar ()) - pure (ServerNormalResponse req mempty StatusOk (StatusDetails "")) + do + forkIO (threadDelay 5000 >> putMVar exitVar ()) + pure (ServerNormalResponse req mempty StatusOk (StatusDetails "")) main :: IO () -main = do exitVar <- newEmptyMVar +main = do + exitVar <- newEmptyMVar - forkIO $ simpleServiceServer (SimpleService - { simpleServicedone = handleDone exitVar - , simpleServicenormalCall = handleNormalCall - , simpleServiceclientStreamingCall = handleClientStreamingCall - , simpleServiceserverStreamingCall = handleServerStreamingCall - , simpleServicebiDiStreamingCall = handleBiDiStreamingCall }) - defaultServiceOptions + forkIO $ + simpleServiceServer + ( SimpleService + { simpleServicedone = handleDone exitVar + , simpleServicenormalCall = handleNormalCall + , simpleServiceclientStreamingCall = handleClientStreamingCall + , simpleServiceserverStreamingCall = handleServerStreamingCall + , simpleServicebiDiStreamingCall = handleBiDiStreamingCall + } + ) + defaultServiceOptions - takeMVar exitVar + takeMVar exitVar