Skip to content

Commit c06fe27

Browse files
More smart receiver buffer
1 parent 1b0b59d commit c06fe27

File tree

6 files changed

+46
-11
lines changed

6 files changed

+46
-11
lines changed

src/Database/PostgreSQL/Driver/Connection.hs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -161,12 +161,12 @@ authorize rawConn settings = do
161161
readAuthResponse = do
162162
-- 1024 should be enough for the auth response from a server at
163163
-- the startup phase.
164-
resp <- rReceive rawConn 1024
164+
resp <- rReceive rawConn mempty 1024
165165
case runDecode decodeAuthResponse resp of
166166
(rest, r) -> case r of
167167
AuthenticationOk ->
168168
parseParameters
169-
(\bs -> (bs <>) <$> rReceive rawConn 1024) rest
169+
(\bs -> rReceive rawConn bs 1024) rest
170170
AuthenticationCleartextPassword ->
171171
performPasswordAuth makePlainPassword
172172
AuthenticationMD5Password (MD5Salt salt) ->
@@ -219,8 +219,10 @@ buildConnection rawConn connParams receiverAction = do
219219
}
220220

221221
-- | Parses connection parameters.
222-
parseParameters :: (B.ByteString -> IO B.ByteString)
223-
-> B.ByteString -> IO (Either Error ConnectionParameters)
222+
parseParameters
223+
:: (B.ByteString -> IO B.ByteString)
224+
-> B.ByteString
225+
-> IO (Either Error ConnectionParameters)
224226
parseParameters action str = Right <$> do
225227
dict <- parseDict str HM.empty
226228
serverVersion <- eitherToProtocolEx . parseServerVersion =<<
@@ -261,7 +263,7 @@ receiverThread :: RawConnection -> InDataChan -> IO ()
261263
receiverThread rawConn dataChan = loopExtractDataRows
262264
-- TODO
263265
-- dont append strings. Allocate buffer manually and use unsafeReceive
264-
(\bs -> (bs <>) <$> rReceive rawConn 4096)
266+
(\bs -> rReceive rawConn bs 4096)
265267
(writeChan dataChan . Right)
266268

267269
-- | Any exception prevents thread from future work.
@@ -279,7 +281,7 @@ receiverThreadCommon rawConn chan msgFilter ntfHandler = go ""
279281

280282
-- TODO
281283
-- dont append strings. Allocate buffer manually and use unsafeReceive
282-
readMoreAction = (\bs -> (bs <>) <$> rReceive rawConn 4096)
284+
readMoreAction = (\bs -> rReceive rawConn bs 4096)
283285
handler msg = do
284286
dispatchIfNotification msg ntfHandler
285287
when (msgFilter msg) $ writeChan chan $ Right msg

src/Database/PostgreSQL/Driver/Error.hs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@ module Database.PostgreSQL.Driver.Error
77
, ReceiverException(..)
88
, IncorrectUsage
99
, ProtocolException
10+
, PeerClosedConnection
1011
-- * helpers
1112
, throwIncorrectUsage
1213
, throwProtocolEx
14+
, throwClosedException
1315
, eitherToProtocolEx
1416
, throwErrorInIO
1517
, throwAuthErrorInIO
@@ -61,18 +63,30 @@ instance Exception ProtocolException where
6163
displayException (ProtocolException msg) =
6264
"Exception in protocol, " ++ BS.unpack msg
6365

66+
-- | Exception throw when remote peer closes connections.
67+
data PeerClosedConnection = PeerClosedConnection
68+
deriving (Show)
69+
70+
instance Exception PeerClosedConnection where
71+
displayException _ = "Remote peer closed the connection"
72+
6473
throwIncorrectUsage :: ByteString -> IO a
6574
throwIncorrectUsage = throwIO . IncorrectUsage
6675

6776
throwProtocolEx :: ByteString -> IO a
6877
throwProtocolEx = throwIO . ProtocolException
6978

79+
throwClosedException :: IO a
80+
throwClosedException = throwIO PeerClosedConnection
81+
7082
eitherToProtocolEx :: Either ByteString a -> IO a
7183
eitherToProtocolEx = either throwProtocolEx pure
7284

85+
-- TODO rename without throw since it actually does not throw exceptions
7386
throwErrorInIO :: Error -> IO (Either Error a)
7487
throwErrorInIO = pure . Left
7588

89+
-- TODO rename without throw since it actually does not throw exceptions
7690
throwAuthErrorInIO :: AuthError -> IO (Either Error a)
7791
throwAuthErrorInIO = pure . Left . AuthError
7892

src/Database/PostgreSQL/Driver/Query.hs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ sendBatchEndBy msg conn qs = do
8282
batch <- constructBatch conn qs
8383
sendEncode conn $ batch <> encodeClientMessage msg
8484

85+
-- INFO about invalid statement in cache
8586
constructBatch :: Connection -> [Query] -> IO Encode
8687
constructBatch conn = fmap fold . traverse constructSingle
8788
where

src/Database/PostgreSQL/Driver/RawConnection.hs

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,23 @@ module Database.PostgreSQL.Driver.RawConnection
44
, createRawConnection
55
) where
66

7-
import Control.Monad (void)
7+
import Control.Monad (void, when)
88
import Control.Exception (bracketOnError, try)
99
import Safe (headMay)
1010
import Data.Monoid ((<>))
11+
import Foreign (castPtr, plusPtr)
1112
import System.Socket (socket, AddressInfo(..), getAddressInfo, socketAddress,
1213
aiV4Mapped, AddressInfoException, Socket, connect,
1314
close, receive, send)
15+
import System.Socket.Unsafe (unsafeReceive)
1416
import System.Socket.Family.Inet (Inet)
1517
import System.Socket.Type.Stream (Stream, sendAll)
1618
import System.Socket.Protocol.TCP (TCP)
1719
import System.Socket.Protocol.Default (Default)
1820
import System.Socket.Family.Unix (Unix, socketAddressUnixPath)
1921
import qualified Data.ByteString as B
22+
import qualified Data.ByteString.Internal as B
23+
import qualified Data.ByteString.Unsafe as B
2024
import qualified Data.ByteString.Char8 as BS(pack)
2125

2226
import Database.PostgreSQL.Driver.Error
@@ -27,7 +31,8 @@ data RawConnection = RawConnection
2731
{ rFlush :: IO ()
2832
, rClose :: IO ()
2933
, rSend :: B.ByteString -> IO ()
30-
, rReceive :: Int -> IO B.ByteString
34+
-- ByteString that should be prepended to received ByteString
35+
, rReceive :: B.ByteString -> Int -> IO B.ByteString
3136
}
3237

3338
defaultUnixPathDirectory :: B.ByteString
@@ -75,6 +80,17 @@ constructRawConnection s = RawConnection
7580
{ rFlush = pure ()
7681
, rClose = close s
7782
, rSend = \msg -> void $ sendAll s msg mempty
78-
, rReceive = \n -> receive s n mempty
83+
, rReceive = rawReceive s
7984
}
8085

86+
{-# INLINE rawReceive #-}
87+
rawReceive :: Socket f Stream p -> B.ByteString -> Int -> IO B.ByteString
88+
rawReceive s bs n = B.unsafeUseAsCStringLen bs $ \(prevPtr, prevLen) ->
89+
let bufSize = prevLen + n
90+
in B.createUptoN bufSize $ \bufPtr -> do
91+
B.memcpy bufPtr (castPtr prevPtr) prevLen
92+
len <- unsafeReceive s (bufPtr `plusPtr` prevLen)
93+
(fromIntegral bufSize) mempty
94+
-- Received empty string means closed connection by the remote host
95+
when (len == 0) throwClosedException
96+
pure $ prevLen + fromIntegral len

src/Database/PostgreSQL/Driver/StatementStorage.hs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ import Database.PostgreSQL.Protocol.Types
1010

1111
-- | Prepared statement storage
1212
data StatementStorage = StatementStorage
13-
(H.BasicHashTable StatementSQL StatementName) (IORef Word)
13+
!(H.BasicHashTable StatementSQL StatementName) !(IORef Word)
1414

1515
-- | Cache policy about prepared statements.
1616
data CachePolicy
@@ -24,6 +24,8 @@ newStatementStorage = StatementStorage <$> H.new <*> newIORef 0
2424
lookupStatement :: StatementStorage -> StatementSQL -> IO (Maybe StatementName)
2525
lookupStatement (StatementStorage table _) = H.lookup table
2626

27+
-- TODO place right name
28+
-- TODO info about exceptions and mask
2729
storeStatement :: StatementStorage -> StatementSQL -> IO StatementName
2830
storeStatement (StatementStorage table counter) stmt = do
2931
n <- readIORef counter

src/Database/PostgreSQL/Protocol/DataRows.hs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ module Database.PostgreSQL.Protocol.DataRows
66
, decodeOneRow
77
) where
88

9-
import Data.Monoid ((<>))
9+
import Data.Monoid ((<>))
1010
import Data.Word (Word8, byteSwap32)
1111
import Foreign (peek, peekByteOff, castPtr)
1212
import qualified Data.ByteString as B

0 commit comments

Comments
 (0)