From b3ebcbf32273be00c92ef59fb0582224733d1444 Mon Sep 17 00:00:00 2001 From: Joey Adams Date: Thu, 3 Jan 2013 19:22:19 -0500 Subject: [PATCH] Add execPrint, execWithCallback, and interruptibly --- Database/SQLite3.hs | 100 +++++++++++++++++++++++++++++++++++-- Database/SQLite3/Direct.hs | 79 ++++++++++++++++++++++++++++- direct-sqlite.cabal | 3 +- test/Main.hs | 47 +++++++++++++++++ 4 files changed, 224 insertions(+), 5 deletions(-) diff --git a/Database/SQLite3.hs b/Database/SQLite3.hs index d27eee4..ab69eef 100644 --- a/Database/SQLite3.hs +++ b/Database/SQLite3.hs @@ -1,3 +1,5 @@ +{-# LANGUAGE BangPatterns #-} +{-# LANGUAGE CPP #-} {-# LANGUAGE DeriveDataTypeable #-} {-# LANGUAGE OverloadedStrings #-} module Database.SQLite3 ( @@ -8,6 +10,9 @@ module Database.SQLite3 ( -- * Simple query execution -- | exec, + execPrint, + execWithCallback, + ExecCallback, -- * Statement management prepare, @@ -51,6 +56,7 @@ module Database.SQLite3 ( -- * Interrupting a long-running query interrupt, + interruptibly, -- * Types Database, @@ -99,11 +105,14 @@ import qualified Database.SQLite3.Direct as Direct import Prelude hiding (error) import qualified Data.Text as T +import qualified Data.Text.IO as T import Control.Applicative ((<$>)) -import Control.Exception (Exception, evaluate, throw, throwIO) +import Control.Concurrent +import Control.Exception import Control.Monad (when, zipWithM_) import Data.ByteString (ByteString) import Data.Int (Int64) +import Data.Maybe (fromMaybe) import Data.Text (Text) import Data.Text.Encoding (encodeUtf8, decodeUtf8With) import Data.Text.Encoding.Error (UnicodeException(..), lenientDecode) @@ -152,9 +161,14 @@ instance Show SQLError where instance Exception SQLError +-- | Like 'decodeUtf8', but substitute a custom error message if +-- decoding fails. fromUtf8 :: String -> Utf8 -> IO Text -fromUtf8 desc (Utf8 bs) = - evaluate $ decodeUtf8With (\_ c -> throw (DecodeError desc c)) bs +fromUtf8 desc utf8 = evaluate $ fromUtf8' desc utf8 + +fromUtf8' :: String -> Utf8 -> Text +fromUtf8' desc (Utf8 bs) = + decodeUtf8With (\_ c -> throw (DecodeError desc c)) bs toUtf8 :: Text -> Utf8 toUtf8 = Utf8 . encodeUtf8 @@ -206,12 +220,92 @@ close :: Database -> IO () close db = Direct.close db >>= checkError (DetailDatabase db) "close" +-- | Make it possible to interrupt the given database operation with an +-- asynchronous exception. This only works if the program is compiled with +-- base >= 4.3 and @-threaded@. +-- +-- It works by running the callback in a forked thread. If interrupted, +-- it uses 'interrupt' to try to stop the operation. +interruptibly :: Database -> IO a -> IO a +#if MIN_VERSION_base(4,3,0) +interruptibly db io + | rtsSupportsBoundThreads = + mask $ \restore -> do + mv <- newEmptyMVar + tid <- forkIO $ try' (restore io) >>= putMVar mv + + let interruptAndWait = + -- Don't let a second exception interrupt us. Otherwise, + -- the operation will dangle in the background, which could + -- be really bad if it uses locally-allocated resources. + uninterruptibleMask_ $ do + -- Tell SQLite3 to interrupt the current query. + interrupt db + + -- Interrupt the thread in case it's blocked for some + -- other reason. + -- + -- NOTE: killThread blocks until the exception is delivered. + -- That's fine, since we're going to wait for the thread + -- to finish anyway. + killThread tid + + -- Wait for the forked thread to finish. + _ <- takeMVar mv + return () + + e <- takeMVar mv `onException` interruptAndWait + either throwIO return e + | otherwise = io + where + try' :: IO a -> IO (Either SomeException a) + try' = try +#else +interruptibly _db io = io +#endif + -- | Execute zero or more SQL statements delimited by semicolons. exec :: Database -> Text -> IO () exec db sql = Direct.exec db (toUtf8 sql) >>= checkErrorMsg ("exec " `appendShow` sql) +-- | Like 'exec', but print result rows to 'System.IO.stdout'. +-- +-- This is mainly for convenience when experimenting in GHCi. +-- The output format may change in the future. +execPrint :: Database -> Text -> IO () +execPrint !db !sql = + interruptibly db $ + execWithCallback db sql $ \_count _colnames -> T.putStrLn . showValues + where + -- This mimics sqlite3's default output mode. It displays a NULL and an + -- empty string identically. + showValues = T.intercalate "|" . map (fromMaybe "") + +-- | Like 'exec', but invoke the callback for each result row. +execWithCallback :: Database -> Text -> ExecCallback -> IO () +execWithCallback db sql cb = + Direct.execWithCallback db (toUtf8 sql) cb' + >>= checkErrorMsg ("execWithCallback " `appendShow` sql) + where + -- We want 'names' computed once and shared with every call. + cb' count namesUtf8 = + let names = map fromUtf8'' namesUtf8 + {-# NOINLINE names #-} + in \valuesUtf8 -> cb count names (map (fmap fromUtf8'') valuesUtf8) + + fromUtf8'' = fromUtf8' "Database.SQLite3.execWithCallback: Invalid UTF-8" + +type ExecCallback + = ColumnCount -- ^ Number of columns, which is the number of items in + -- the following lists. This will be the same for + -- every row. + -> [Text] -- ^ List of column names. This will be the same + -- for every row. + -> [Maybe Text] -- ^ List of column values, as returned by 'columnText'. + -> IO () + -- | -- -- Unlike 'exec', 'prepare' only executes the first statement, and ignores diff --git a/Database/SQLite3/Direct.hs b/Database/SQLite3/Direct.hs index c26be68..ce63863 100644 --- a/Database/SQLite3/Direct.hs +++ b/Database/SQLite3/Direct.hs @@ -1,3 +1,4 @@ +{-# LANGUAGE BangPatterns #-} {-# LANGUAGE DeriveDataTypeable #-} -- | -- This API is a slightly lower-level version of "Database.SQLite3". Namely: @@ -11,11 +12,12 @@ module Database.SQLite3.Direct ( open, close, errmsg, - interrupt, -- * Simple query execution -- | exec, + execWithCallback, + ExecCallback, -- * Statement management prepare, @@ -51,6 +53,9 @@ module Database.SQLite3.Direct ( changes, totalChanges, + -- * Interrupting a long-running query + interrupt, + -- * Types Database(..), Statement(..), @@ -74,7 +79,10 @@ import qualified Data.ByteString.Unsafe as BSU import qualified Data.Text as T import qualified Data.Text.Encoding as T import Control.Applicative ((<$>)) +import Control.Exception as E +import Control.Monad (join) import Data.ByteString (ByteString) +import Data.IORef import Data.Monoid import Data.String (IsString(..)) import Data.Text.Encoding.Error (lenientDecode) @@ -116,6 +124,10 @@ packCStringLen :: CString -> CNumBytes -> IO ByteString packCStringLen cstr len = BS.packCStringLen (cstr, fromIntegral len) +packUtf8Array :: IO a -> (Utf8 -> IO a) -> Int -> Ptr CString -> IO [a] +packUtf8Array onNull onUtf8 count base = + peekArray count base >>= mapM (join . packUtf8 onNull onUtf8) + -- | Like 'unsafeUseAsCStringLen', but if the string is empty, -- never pass the callback a null pointer. unsafeUseAsCStringLenNoNull :: ByteString -> (CString -> CNumBytes -> IO a) -> IO a @@ -207,6 +219,71 @@ exec (Database db) (Utf8 sql) = return $ Left (err, msg) Right () -> return $ Right () +-- | Like 'exec', but invoke the callback for each result row. +-- +-- If the callback throws an exception, it will be rethrown by +-- 'execWithCallback'. +execWithCallback :: Database -> Utf8 -> ExecCallback -> IO (Either (Error, Utf8) ()) +execWithCallback (Database db) (Utf8 sql) cb = do + abortReason <- newIORef Nothing :: IO (IORef (Maybe SomeException)) + cbCache <- newIORef Nothing :: IO (IORef (Maybe ([Maybe Utf8] -> IO ()))) + -- Cache the partial application of column count and name, so if the + -- caller wants to convert them to something else, it only has to do + -- the conversions once. + + let getCallback cCount cNames = do + m <- readIORef cbCache + case m of + Nothing -> do + names <- packUtf8Array (fail "execWithCallback: NULL column name") + return + (fromIntegral cCount) cNames + let !cb' = cb (fromFFI cCount) names + writeIORef cbCache $ Just cb' + return cb' + Just cb' -> return cb' + + let onExceptionAbort io = + (io >> return 0) `E.catch` \ex -> do + writeIORef abortReason $ Just ex + return 1 + + let cExecCallback _ctx cCount cValues cNames = + onExceptionAbort $ do + cb' <- getCallback cCount cNames + values <- packUtf8Array (return Nothing) + (return . Just) + (fromIntegral cCount) cValues + cb' values + + BS.useAsCString sql $ \sql' -> + alloca $ \msgPtrOut -> + bracket (mkCExecCallback cExecCallback) freeHaskellFunPtr $ + \pExecCallback -> do + let returnError err = do + msgPtr <- peek msgPtrOut + msg <- packUtf8 (Utf8 BS.empty) id msgPtr + c_sqlite3_free msgPtr + return $ Left (err, msg) + rc <- c_sqlite3_exec db sql' pExecCallback nullPtr msgPtrOut + case toResult () rc of + Left ErrorAbort -> do + m <- readIORef abortReason + case m of + Nothing -> returnError ErrorAbort + Just ex -> throwIO ex + Left err -> returnError err + Right () -> return $ Right () + +type ExecCallback + = ColumnCount -- ^ Number of columns, which is the number of items in + -- the following lists. This will be the same for + -- every row. + -> [Utf8] -- ^ List of column names. This will be the same + -- for every row. + -> [Maybe Utf8] -- ^ List of column values, as returned by 'columnText'. + -> IO () + -- | -- -- If the query contains no SQL statements, this returns diff --git a/direct-sqlite.cabal b/direct-sqlite.cabal index 4a0eeb1..ac0f42a 100644 --- a/direct-sqlite.cabal +++ b/direct-sqlite.cabal @@ -89,7 +89,8 @@ test-suite test default-language: Haskell2010 - default-extensions: NamedFieldPuns + default-extensions: DeriveDataTypeable + , NamedFieldPuns , OverloadedStrings , Rank2Types , RecordWildCards diff --git a/test/Main.hs b/test/Main.hs index 530d042..8d9374b 100644 --- a/test/Main.hs +++ b/test/Main.hs @@ -8,10 +8,12 @@ import Control.Exception import Control.Monad (forM_, liftM3, when) import Data.Text (Text) import Data.Text.Encoding.Error (UnicodeException(..)) +import Data.Typeable import System.Directory import System.Exit (exitFailure) import System.IO import System.IO.Error (isDoesNotExistError, isUserError) +import System.Timeout (timeout) import Test.HUnit import qualified Data.ByteString as B @@ -33,6 +35,7 @@ data TestEnv = regressionTests :: [TestEnv -> Test] regressionTests = [ TestLabel "Exec" . testExec + , TestLabel "ExecCallback" . testExecCallback , TestLabel "Simple" . testSimplest , TestLabel "Prepare" . testPrepare , TestLabel "CloseBusy" . testCloseBusy @@ -105,6 +108,47 @@ testExec TestEnv{..} = TestCase $ do Done <- step stmt return () +data Ex = Ex + deriving (Show, Typeable) + +instance Exception Ex + +testExecCallback :: TestEnv -> Test +testExecCallback TestEnv{..} = TestCase $ + withConn $ \conn -> do + chan <- newChan + let exec' sql = execWithCallback conn sql $ \c n v -> writeChan chan (c, n, v) + exec' "CREATE TABLE foo (a INT, b TEXT); \ + \INSERT INTO foo VALUES (1, 'a'); \ + \INSERT INTO foo VALUES (2, 'b'); \ + \INSERT INTO foo VALUES (3, null); \ + \INSERT INTO foo VALUES (null, 'd'); " + + exec' "SELECT 1, 2, 3" + (3, ["1","2","3"], [Just "1", Just "2", Just "3"]) <- readChan chan + + exec' "SELECT null" + (1, ["null"], [Nothing]) <- readChan chan + + exec' "SELECT * FROM foo" + (2, ["a","b"], [Just "1", Just "a"]) <- readChan chan + (2, ["a","b"], [Just "2", Just "b"]) <- readChan chan + (2, ["a","b"], [Just "3", Nothing ]) <- readChan chan + (2, ["a","b"], [Nothing, Just "d"]) <- readChan chan + + exec' "SELECT * FROM foo WHERE a < 0; SELECT 123" + (1, ["123"], [Just "123"]) <- readChan chan + + exec' "SELECT rowid, f.a, f.b, a || b FROM foo AS f" + (4, ["rowid", "a", "b", "a || b"], [Just "1", Just "1", Just "a", Just "1a"]) <- readChan chan + (4, ["rowid", "a", "b", "a || b"], [Just "2", Just "2", Just "b", Just "2b"]) <- readChan chan + (4, ["rowid", "a", "b", "a || b"], [Just "3", Just "3", Nothing , Nothing ]) <- readChan chan + (4, ["rowid", "a", "b", "a || b"], [Just "4", Nothing , Just "d", Nothing ]) <- readChan chan + + Left Ex <- try $ execWithCallback conn "SELECT 1" $ \_ _ _ -> throwIO Ex + + return () + -- Simplest SELECT testSimplest :: TestEnv -> Test testSimplest TestEnv{..} = TestCase $ do @@ -506,6 +550,9 @@ testInterrupt TestEnv{..} = TestCase $ _ <- forkIO $ threadDelay 100000 >> interrupt conn Left ErrorInterrupt <- Direct.step stmt Left ErrorInterrupt <- Direct.finalize stmt + + Nothing <- timeout 100000 $ interruptibly conn $ exec conn tripleSum + return () where