Skip to content

Update to Z.Haskell, fix auth issues, various updates #40

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
197 changes: 58 additions & 139 deletions Database/MySQL/Base.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ but you shouldn't try to catch them if you don't have a recovery plan,
for example: there's no meaning to catch a 'ERRException' during authentication unless you want to try different passwords.
By using this library you will meet:

* 'NetworkException': underline network is broken.
* 'UnconsumedResultSet': you should consume previous resultset before sending new command.
* 'ERRException': you receive a 'ERR' packet when you shouldn't.
* 'UnexpectedPacket': you receive a unexpected packet when you shouldn't.
Expand All @@ -32,23 +31,19 @@ module Database.MySQL.Base
, defaultConnectInfoMB4
, connect
, connectDetail
, close
, ping
-- * Direct query
, execute
, executeMany
, executeMany_
, execute_
, query_
, queryVector_
, query
, queryVector
-- * Prepared query statement
, prepareStmt
, prepareStmtDetail
, executeStmt
, queryStmt
, queryStmtVector
, closeStmt
, resetStmt
-- * Helpers
Expand All @@ -58,13 +53,10 @@ module Database.MySQL.Base
, Query(..)
, renderParams
, command
, Stream.skipToEof
-- * Exceptions
, NetworkException(..)
, UnconsumedResultSet(..)
, ERRException(..)
, UnexpectedPacket(..)
, DecodePacketException(..)
, WrongParamsCount(..)
-- * MySQL protocol
, module Database.MySQL.Protocol.Auth
Expand All @@ -75,21 +67,20 @@ module Database.MySQL.Base
) where

import Control.Applicative
import Control.Exception (mask, onException, throwIO)
import Control.Monad
import qualified Data.ByteString.Lazy as L
import Data.IORef (writeIORef)
import Database.MySQL.Connection
import Database.MySQL.Protocol.Auth
import Database.MySQL.Protocol.ColumnDef
import Database.MySQL.Protocol.Command
import Database.MySQL.Protocol.MySQLValue
import Database.MySQL.Protocol.Packet

import Database.MySQL.Query
import System.IO.Streams (InputStream)
import qualified System.IO.Streams as Stream
import qualified Data.Vector as V
import Z.Data.ASCII
import qualified Z.Data.Vector as V
import qualified Z.Data.Text as T
import Z.IO
import Z.IO.BIO (Source, sourceFromIO)

--------------------------------------------------------------------------------

Expand All @@ -100,11 +91,9 @@ import qualified Data.Vector as V
-- and you should consider using prepared statement if this's not an one shot query.
--
execute :: QueryParam p => MySQLConn -> Query -> [p] -> IO OK
{-# INLINABLE execute #-}
execute conn qry params = execute_ conn (renderParams qry params)

{-# SPECIALIZE execute :: MySQLConn -> Query -> [MySQLValue] -> IO OK #-}
{-# SPECIALIZE execute :: MySQLConn -> Query -> [Param] -> IO OK #-}

-- | Execute a multi-row query which don't return result-set.
--
-- Leverage MySQL's multi-statement support to do batch insert\/update\/delete,
Expand All @@ -114,15 +103,13 @@ execute conn qry params = execute_ conn (renderParams qry params)
-- @since 0.2.0.0
--
executeMany :: QueryParam p => MySQLConn -> Query -> [[p]] -> IO [OK]
executeMany conn@(MySQLConn is os _ _) qry paramsList = do
{-# INLINABLE executeMany #-}
executeMany conn@(MySQLConn is os _) qry paramsList = do
guardUnconsumed conn
let qry' = L.intercalate ";" $ map (fromQuery . renderParams qry) paramsList
let qry' = V.intercalate ";" $ map (fromQuery . renderParams qry) paramsList
writeCommand (COM_QUERY qry') os
mapM (\ _ -> waitCommandReply is) paramsList

{-# SPECIALIZE executeMany :: MySQLConn -> Query -> [[MySQLValue]] -> IO [OK] #-}
{-# SPECIALIZE executeMany :: MySQLConn -> Query -> [[Param]] -> IO [OK] #-}

-- | Execute multiple querys (without param) which don't return result-set.
--
-- This's useful when your want to execute multiple SQLs without params, e.g. from a
Expand All @@ -131,7 +118,8 @@ executeMany conn@(MySQLConn is os _ _) qry paramsList = do
-- @since 0.8.4.0
--
executeMany_ :: MySQLConn -> Query -> IO [OK]
executeMany_ conn@(MySQLConn is os _ _) qry = do
{-# INLINABLE executeMany_ #-}
executeMany_ conn@(MySQLConn is os _) qry = do
guardUnconsumed conn
writeCommand (COM_QUERY (fromQuery qry)) os
waitCommandReplys is
Expand All @@ -147,118 +135,75 @@ execute_ conn (Query qry) = command conn (COM_QUERY qry)
-- the same 'MySQLConn', or an 'UnconsumedResultSet' will be thrown.
-- if you want to skip the result-set, use 'Stream.skipToEof'.
--
query :: QueryParam p => MySQLConn -> Query -> [p] -> IO ([ColumnDef], InputStream [MySQLValue])
query :: QueryParam p => MySQLConn -> Query -> [p] -> IO (V.Vector ColumnDef, Source (V.Vector MySQLValue))
query conn qry params = query_ conn (renderParams qry params)

{-# SPECIALIZE query :: MySQLConn -> Query -> [MySQLValue] -> IO ([ColumnDef], InputStream [MySQLValue]) #-}
{-# SPECIALIZE query :: MySQLConn -> Query -> [Param] -> IO ([ColumnDef], InputStream [MySQLValue]) #-}

-- | 'V.Vector' version of 'query'.
--
-- @since 0.5.1.0
--
queryVector :: QueryParam p => MySQLConn -> Query -> [p] -> IO (V.Vector ColumnDef, InputStream (V.Vector MySQLValue))
queryVector conn qry params = queryVector_ conn (renderParams qry params)

{-# SPECIALIZE queryVector :: MySQLConn -> Query -> [MySQLValue] -> IO (V.Vector ColumnDef, InputStream (V.Vector MySQLValue)) #-}
{-# SPECIALIZE queryVector :: MySQLConn -> Query -> [Param] -> IO (V.Vector ColumnDef, InputStream (V.Vector MySQLValue)) #-}
readFields :: HasCallStack => Int -> BufferedInput -> IO (V.Vector ColumnDef)
{-# INLINABLE readFields #-}
readFields len is = V.replicateM len (decodeFromPacket decodeField =<< readPacket is)

-- | Execute a MySQL query which return a result-set.
--
query_ :: MySQLConn -> Query -> IO ([ColumnDef], InputStream [MySQLValue])
query_ conn@(MySQLConn is os _ consumed) (Query qry) = do
query_ :: HasCallStack => MySQLConn -> Query -> IO (V.Vector ColumnDef, Source (V.Vector MySQLValue))
query_ conn@(MySQLConn is os consumed) (Query qry) = do
guardUnconsumed conn
writeCommand (COM_QUERY qry) os
p <- readPacket is
if isERR p
then decodeFromPacket p >>= throwIO . ERRException
else do
len <- getFromPacket getLenEncInt p
fields <- replicateM len $ (decodeFromPacket <=< readPacket) is
_ <- readPacket is -- eof packet, we don't verify this though
writeIORef consumed False
rows <- Stream.makeInputStream $ do
len <- decodeFromPacket decodeLenEncInt p
fields <- readFields len is
_ <- readPacket is -- eof packet, we don't verify this though
writeIORef consumed False
let rows = sourceFromIO $ do
q <- readPacket is
if | isEOF q -> writeIORef consumed True >> return Nothing
| isERR q -> decodeFromPacket q >>= throwIO . ERRException
| otherwise -> Just <$> getFromPacket (getTextRow fields) q
return (fields, rows)

-- | 'V.Vector' version of 'query_'.
--
-- @since 0.5.1.0
--
queryVector_ :: MySQLConn -> Query -> IO (V.Vector ColumnDef, InputStream (V.Vector MySQLValue))
queryVector_ conn@(MySQLConn is os _ consumed) (Query qry) = do
guardUnconsumed conn
writeCommand (COM_QUERY qry) os
p <- readPacket is
if isERR p
then decodeFromPacket p >>= throwIO . ERRException
else do
len <- getFromPacket getLenEncInt p
fields <- V.replicateM len $ (decodeFromPacket <=< readPacket) is
_ <- readPacket is -- eof packet, we don't verify this though
writeIORef consumed False
rows <- Stream.makeInputStream $ do
q <- readPacket is
if | isEOF q -> writeIORef consumed True >> return Nothing
| isERR q -> decodeFromPacket q >>= throwIO . ERRException
| otherwise -> Just <$> getFromPacket (getTextRowVector fields) q
return (fields, rows)
if isEOF q
then writeIORef consumed True >> return Nothing
else Just <$!> decodeFromPacket (decodeTextRow fields) q
return (fields, rows)

-- | Ask MySQL to prepare a query statement.
--
prepareStmt :: MySQLConn -> Query -> IO StmtID
prepareStmt conn@(MySQLConn is os _ _) (Query stmt) = do
prepareStmt :: HasCallStack => MySQLConn -> Query -> IO StmtID
prepareStmt conn@(MySQLConn is os _) (Query stmt) = do
guardUnconsumed conn
writeCommand (COM_STMT_PREPARE stmt) os
p <- readPacket is
if isERR p
then decodeFromPacket p >>= throwIO . ERRException
else do
StmtPrepareOK stid colCnt paramCnt _ <- getFromPacket getStmtPrepareOK p
_ <- replicateM_ paramCnt (readPacket is)
_ <- unless (paramCnt == 0) (void (readPacket is)) -- EOF
_ <- replicateM_ colCnt (readPacket is)
_ <- unless (colCnt == 0) (void (readPacket is)) -- EOF
return stid
StmtPrepareOK stid colCnt paramCnt _ <- decodeFromPacket decodeStmtPrepareOK p
_ <- replicateM_ (fromIntegral paramCnt) (readPacket is)
_ <- unless (paramCnt == 0) (void (readPacket is)) -- EOF
_ <- replicateM_ (fromIntegral colCnt) (readPacket is)
_ <- unless (colCnt == 0) (void (readPacket is)) -- EOF
return stid

-- | Ask MySQL to prepare a query statement.
--
-- All details from @COM_STMT_PREPARE@ Response are returned: the 'StmtPrepareOK' packet,
-- params's 'ColumnDef', result's 'ColumnDef'.
--
prepareStmtDetail :: MySQLConn -> Query -> IO (StmtPrepareOK, [ColumnDef], [ColumnDef])
prepareStmtDetail conn@(MySQLConn is os _ _) (Query stmt) = do
prepareStmtDetail :: HasCallStack => MySQLConn -> Query -> IO (StmtPrepareOK, V.Vector ColumnDef, V.Vector ColumnDef)
prepareStmtDetail conn@(MySQLConn is os _) (Query stmt) = do
guardUnconsumed conn
writeCommand (COM_STMT_PREPARE stmt) os
p <- readPacket is
if isERR p
then decodeFromPacket p >>= throwIO . ERRException
else do
sOK@(StmtPrepareOK _ colCnt paramCnt _) <- getFromPacket getStmtPrepareOK p
pdefs <- replicateM paramCnt ((decodeFromPacket <=< readPacket) is)
_ <- unless (paramCnt == 0) (void (readPacket is)) -- EOF
cdefs <- replicateM colCnt ((decodeFromPacket <=< readPacket) is)
_ <- unless (colCnt == 0) (void (readPacket is)) -- EOF
return (sOK, pdefs, cdefs)
sOK@(StmtPrepareOK _ colCnt paramCnt _) <- decodeFromPacket decodeStmtPrepareOK p
pdefs <- readFields (fromIntegral paramCnt) is
_ <- unless (paramCnt == 0) (void (readPacket is)) -- EOF
cdefs <- readFields (fromIntegral colCnt) is
_ <- unless (colCnt == 0) (void (readPacket is)) -- EOF
return (sOK, pdefs, cdefs)

-- | Ask MySQL to closed a query statement.
--
closeStmt :: MySQLConn -> StmtID -> IO ()
closeStmt (MySQLConn _ os _ _) stid = do
closeStmt (MySQLConn _ os _) stid = do
writeCommand (COM_STMT_CLOSE stid) os

-- | Ask MySQL to reset a query statement, all previous resultset will be cleared.
--
-- This function can be used when you want to stop a long running query from another thread.
-- Which will lead a thread running `queryStmt` reach its EOF.
resetStmt :: MySQLConn -> StmtID -> IO ()
resetStmt (MySQLConn is os _ consumed) stid = do
resetStmt (MySQLConn _ os _) stid = do
writeCommand (COM_STMT_RESET stid) os -- previous result-set may still be unconsumed
p <- readPacket is
if isERR p
then decodeFromPacket p >>= throwIO . ERRException
else writeIORef consumed True

-- | Execute prepared query statement with parameters, expecting no resultset.
--
Expand All @@ -270,49 +215,23 @@ executeStmt conn stid params =
--
-- Rules about 'UnconsumedResultSet' applied here too.
--
queryStmt :: MySQLConn -> StmtID -> [MySQLValue] -> IO ([ColumnDef], InputStream [MySQLValue])
queryStmt conn@(MySQLConn is os _ consumed) stid params = do
guardUnconsumed conn
writeCommand (COM_STMT_EXECUTE stid params (makeNullMap params)) os
p <- readPacket is
if isERR p
then decodeFromPacket p >>= throwIO . ERRException
else do
len <- getFromPacket getLenEncInt p
fields <- replicateM len $ (decodeFromPacket <=< readPacket) is
_ <- readPacket is -- eof packet, we don't verify this though
writeIORef consumed False
rows <- Stream.makeInputStream $ do
q <- readPacket is
if | isOK q -> Just <$> getFromPacket (getBinaryRow fields len) q
| isEOF q -> writeIORef consumed True >> return Nothing
| isERR q -> decodeFromPacket q >>= throwIO . ERRException
| otherwise -> throwIO (UnexpectedPacket q)
return (fields, rows)

-- | 'V.Vector' version of 'queryStmt'
--
-- @since 0.5.1.0
--
queryStmtVector :: MySQLConn -> StmtID -> [MySQLValue] -> IO (V.Vector ColumnDef, InputStream (V.Vector MySQLValue))
queryStmtVector conn@(MySQLConn is os _ consumed) stid params = do
queryStmt :: HasCallStack
=> MySQLConn -> StmtID -> [MySQLValue] -> IO (V.Vector ColumnDef, Source (V.Vector MySQLValue))
{-# INLINABLE queryStmt #-}
queryStmt conn@(MySQLConn is os consumed) stid params = do
guardUnconsumed conn
writeCommand (COM_STMT_EXECUTE stid params (makeNullMap params)) os
p <- readPacket is
if isERR p
then decodeFromPacket p >>= throwIO . ERRException
else do
len <- getFromPacket getLenEncInt p
fields <- V.replicateM len $ (decodeFromPacket <=< readPacket) is
_ <- readPacket is -- eof packet, we don't verify this though
writeIORef consumed False
rows <- Stream.makeInputStream $ do
len <- decodeFromPacket decodeLenEncInt p
fields <- readFields len is
_ <- readPacket is -- eof packet, we don't verify this though
writeIORef consumed False
let rows = sourceFromIO $ do
q <- readPacket is
if | isOK q -> Just <$> getFromPacket (getBinaryRowVector fields len) q
if | isOK q -> Just <$!> decodeFromPacket (decodeBinaryRow fields len) q
| isEOF q -> writeIORef consumed True >> return Nothing
| isERR q -> decodeFromPacket q >>= throwIO . ERRException
| otherwise -> throwIO (UnexpectedPacket q)
return (fields, rows)
| otherwise -> throwIO (UnexpectedPacket q callStack)
return (fields, rows)

-- | Run querys inside a transaction, querys will be rolled back if exception arise.
--
Expand Down
Loading