Skip to content

Commit

Permalink
Add execPrint, execWithCallback, and interruptibly
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyadams committed Jan 4, 2013
1 parent 3b1a86f commit b3ebcbf
Show file tree
Hide file tree
Showing 4 changed files with 224 additions and 5 deletions.
100 changes: 97 additions & 3 deletions Database/SQLite3.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
module Database.SQLite3 (
Expand All @@ -8,6 +10,9 @@ module Database.SQLite3 (
-- * Simple query execution
-- | <http://sqlite.org/c3ref/exec.html>
exec,
execPrint,
execWithCallback,
ExecCallback,

-- * Statement management
prepare,
Expand Down Expand Up @@ -51,6 +56,7 @@ module Database.SQLite3 (

-- * Interrupting a long-running query
interrupt,
interruptibly,

-- * Types
Database,
Expand Down Expand Up @@ -99,11 +105,14 @@ import qualified Database.SQLite3.Direct as Direct

import Prelude hiding (error)
import qualified Data.Text as T
import qualified Data.Text.IO as T
import Control.Applicative ((<$>))
import Control.Exception (Exception, evaluate, throw, throwIO)
import Control.Concurrent
import Control.Exception
import Control.Monad (when, zipWithM_)
import Data.ByteString (ByteString)
import Data.Int (Int64)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Text.Encoding (encodeUtf8, decodeUtf8With)
import Data.Text.Encoding.Error (UnicodeException(..), lenientDecode)
Expand Down Expand Up @@ -152,9 +161,14 @@ instance Show SQLError where

instance Exception SQLError

-- | Like 'decodeUtf8', but substitute a custom error message if
-- decoding fails.
fromUtf8 :: String -> Utf8 -> IO Text
fromUtf8 desc (Utf8 bs) =
evaluate $ decodeUtf8With (\_ c -> throw (DecodeError desc c)) bs
fromUtf8 desc utf8 = evaluate $ fromUtf8' desc utf8

fromUtf8' :: String -> Utf8 -> Text
fromUtf8' desc (Utf8 bs) =
decodeUtf8With (\_ c -> throw (DecodeError desc c)) bs

toUtf8 :: Text -> Utf8
toUtf8 = Utf8 . encodeUtf8
Expand Down Expand Up @@ -206,12 +220,92 @@ close :: Database -> IO ()
close db =
Direct.close db >>= checkError (DetailDatabase db) "close"

-- | Make it possible to interrupt the given database operation with an
-- asynchronous exception. This only works if the program is compiled with
-- base >= 4.3 and @-threaded@.
--
-- It works by running the callback in a forked thread. If interrupted,
-- it uses 'interrupt' to try to stop the operation.
interruptibly :: Database -> IO a -> IO a
#if MIN_VERSION_base(4,3,0)
interruptibly db io
| rtsSupportsBoundThreads =
mask $ \restore -> do
mv <- newEmptyMVar
tid <- forkIO $ try' (restore io) >>= putMVar mv

let interruptAndWait =
-- Don't let a second exception interrupt us. Otherwise,
-- the operation will dangle in the background, which could
-- be really bad if it uses locally-allocated resources.
uninterruptibleMask_ $ do
-- Tell SQLite3 to interrupt the current query.
interrupt db

-- Interrupt the thread in case it's blocked for some
-- other reason.
--
-- NOTE: killThread blocks until the exception is delivered.
-- That's fine, since we're going to wait for the thread
-- to finish anyway.
killThread tid

-- Wait for the forked thread to finish.
_ <- takeMVar mv
return ()

e <- takeMVar mv `onException` interruptAndWait
either throwIO return e
| otherwise = io
where
try' :: IO a -> IO (Either SomeException a)
try' = try
#else
interruptibly _db io = io
#endif

-- | Execute zero or more SQL statements delimited by semicolons.
exec :: Database -> Text -> IO ()
exec db sql =
Direct.exec db (toUtf8 sql)
>>= checkErrorMsg ("exec " `appendShow` sql)

-- | Like 'exec', but print result rows to 'System.IO.stdout'.
--
-- This is mainly for convenience when experimenting in GHCi.
-- The output format may change in the future.
execPrint :: Database -> Text -> IO ()
execPrint !db !sql =
interruptibly db $
execWithCallback db sql $ \_count _colnames -> T.putStrLn . showValues
where
-- This mimics sqlite3's default output mode. It displays a NULL and an
-- empty string identically.
showValues = T.intercalate "|" . map (fromMaybe "")

-- | Like 'exec', but invoke the callback for each result row.
execWithCallback :: Database -> Text -> ExecCallback -> IO ()
execWithCallback db sql cb =
Direct.execWithCallback db (toUtf8 sql) cb'
>>= checkErrorMsg ("execWithCallback " `appendShow` sql)
where
-- We want 'names' computed once and shared with every call.
cb' count namesUtf8 =
let names = map fromUtf8'' namesUtf8
{-# NOINLINE names #-}
in \valuesUtf8 -> cb count names (map (fmap fromUtf8'') valuesUtf8)

fromUtf8'' = fromUtf8' "Database.SQLite3.execWithCallback: Invalid UTF-8"

type ExecCallback
= ColumnCount -- ^ Number of columns, which is the number of items in
-- the following lists. This will be the same for
-- every row.
-> [Text] -- ^ List of column names. This will be the same
-- for every row.
-> [Maybe Text] -- ^ List of column values, as returned by 'columnText'.
-> IO ()

-- | <http://www.sqlite.org/c3ref/prepare.html>
--
-- Unlike 'exec', 'prepare' only executes the first statement, and ignores
Expand Down
79 changes: 78 additions & 1 deletion Database/SQLite3/Direct.hs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE DeriveDataTypeable #-}
-- |
-- This API is a slightly lower-level version of "Database.SQLite3". Namely:
Expand All @@ -11,11 +12,12 @@ module Database.SQLite3.Direct (
open,
close,
errmsg,
interrupt,

-- * Simple query execution
-- | <http://sqlite.org/c3ref/exec.html>
exec,
execWithCallback,
ExecCallback,

-- * Statement management
prepare,
Expand Down Expand Up @@ -51,6 +53,9 @@ module Database.SQLite3.Direct (
changes,
totalChanges,

-- * Interrupting a long-running query
interrupt,

-- * Types
Database(..),
Statement(..),
Expand All @@ -74,7 +79,10 @@ import qualified Data.ByteString.Unsafe as BSU
import qualified Data.Text as T
import qualified Data.Text.Encoding as T
import Control.Applicative ((<$>))
import Control.Exception as E
import Control.Monad (join)
import Data.ByteString (ByteString)
import Data.IORef
import Data.Monoid
import Data.String (IsString(..))
import Data.Text.Encoding.Error (lenientDecode)
Expand Down Expand Up @@ -116,6 +124,10 @@ packCStringLen :: CString -> CNumBytes -> IO ByteString
packCStringLen cstr len =
BS.packCStringLen (cstr, fromIntegral len)

packUtf8Array :: IO a -> (Utf8 -> IO a) -> Int -> Ptr CString -> IO [a]
packUtf8Array onNull onUtf8 count base =
peekArray count base >>= mapM (join . packUtf8 onNull onUtf8)

-- | Like 'unsafeUseAsCStringLen', but if the string is empty,
-- never pass the callback a null pointer.
unsafeUseAsCStringLenNoNull :: ByteString -> (CString -> CNumBytes -> IO a) -> IO a
Expand Down Expand Up @@ -207,6 +219,71 @@ exec (Database db) (Utf8 sql) =
return $ Left (err, msg)
Right () -> return $ Right ()

-- | Like 'exec', but invoke the callback for each result row.
--
-- If the callback throws an exception, it will be rethrown by
-- 'execWithCallback'.
execWithCallback :: Database -> Utf8 -> ExecCallback -> IO (Either (Error, Utf8) ())
execWithCallback (Database db) (Utf8 sql) cb = do
abortReason <- newIORef Nothing :: IO (IORef (Maybe SomeException))
cbCache <- newIORef Nothing :: IO (IORef (Maybe ([Maybe Utf8] -> IO ())))
-- Cache the partial application of column count and name, so if the
-- caller wants to convert them to something else, it only has to do
-- the conversions once.

let getCallback cCount cNames = do
m <- readIORef cbCache
case m of
Nothing -> do
names <- packUtf8Array (fail "execWithCallback: NULL column name")
return
(fromIntegral cCount) cNames
let !cb' = cb (fromFFI cCount) names
writeIORef cbCache $ Just cb'
return cb'
Just cb' -> return cb'

let onExceptionAbort io =
(io >> return 0) `E.catch` \ex -> do
writeIORef abortReason $ Just ex
return 1

let cExecCallback _ctx cCount cValues cNames =
onExceptionAbort $ do
cb' <- getCallback cCount cNames
values <- packUtf8Array (return Nothing)
(return . Just)
(fromIntegral cCount) cValues
cb' values

BS.useAsCString sql $ \sql' ->
alloca $ \msgPtrOut ->
bracket (mkCExecCallback cExecCallback) freeHaskellFunPtr $
\pExecCallback -> do
let returnError err = do
msgPtr <- peek msgPtrOut
msg <- packUtf8 (Utf8 BS.empty) id msgPtr
c_sqlite3_free msgPtr
return $ Left (err, msg)
rc <- c_sqlite3_exec db sql' pExecCallback nullPtr msgPtrOut
case toResult () rc of
Left ErrorAbort -> do
m <- readIORef abortReason
case m of
Nothing -> returnError ErrorAbort
Just ex -> throwIO ex
Left err -> returnError err
Right () -> return $ Right ()

type ExecCallback
= ColumnCount -- ^ Number of columns, which is the number of items in
-- the following lists. This will be the same for
-- every row.
-> [Utf8] -- ^ List of column names. This will be the same
-- for every row.
-> [Maybe Utf8] -- ^ List of column values, as returned by 'columnText'.
-> IO ()

-- | <http://www.sqlite.org/c3ref/prepare.html>
--
-- If the query contains no SQL statements, this returns
Expand Down
3 changes: 2 additions & 1 deletion direct-sqlite.cabal
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,8 @@ test-suite test

default-language: Haskell2010

default-extensions: NamedFieldPuns
default-extensions: DeriveDataTypeable
, NamedFieldPuns
, OverloadedStrings
, Rank2Types
, RecordWildCards
Expand Down
47 changes: 47 additions & 0 deletions test/Main.hs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@ import Control.Exception
import Control.Monad (forM_, liftM3, when)
import Data.Text (Text)
import Data.Text.Encoding.Error (UnicodeException(..))
import Data.Typeable
import System.Directory
import System.Exit (exitFailure)
import System.IO
import System.IO.Error (isDoesNotExistError, isUserError)
import System.Timeout (timeout)
import Test.HUnit

import qualified Data.ByteString as B
Expand All @@ -33,6 +35,7 @@ data TestEnv =
regressionTests :: [TestEnv -> Test]
regressionTests =
[ TestLabel "Exec" . testExec
, TestLabel "ExecCallback" . testExecCallback
, TestLabel "Simple" . testSimplest
, TestLabel "Prepare" . testPrepare
, TestLabel "CloseBusy" . testCloseBusy
Expand Down Expand Up @@ -105,6 +108,47 @@ testExec TestEnv{..} = TestCase $ do
Done <- step stmt
return ()

data Ex = Ex
deriving (Show, Typeable)

instance Exception Ex

testExecCallback :: TestEnv -> Test
testExecCallback TestEnv{..} = TestCase $
withConn $ \conn -> do
chan <- newChan
let exec' sql = execWithCallback conn sql $ \c n v -> writeChan chan (c, n, v)
exec' "CREATE TABLE foo (a INT, b TEXT); \
\INSERT INTO foo VALUES (1, 'a'); \
\INSERT INTO foo VALUES (2, 'b'); \
\INSERT INTO foo VALUES (3, null); \
\INSERT INTO foo VALUES (null, 'd'); "

exec' "SELECT 1, 2, 3"
(3, ["1","2","3"], [Just "1", Just "2", Just "3"]) <- readChan chan

exec' "SELECT null"
(1, ["null"], [Nothing]) <- readChan chan

exec' "SELECT * FROM foo"
(2, ["a","b"], [Just "1", Just "a"]) <- readChan chan
(2, ["a","b"], [Just "2", Just "b"]) <- readChan chan
(2, ["a","b"], [Just "3", Nothing ]) <- readChan chan
(2, ["a","b"], [Nothing, Just "d"]) <- readChan chan

exec' "SELECT * FROM foo WHERE a < 0; SELECT 123"
(1, ["123"], [Just "123"]) <- readChan chan

exec' "SELECT rowid, f.a, f.b, a || b FROM foo AS f"
(4, ["rowid", "a", "b", "a || b"], [Just "1", Just "1", Just "a", Just "1a"]) <- readChan chan
(4, ["rowid", "a", "b", "a || b"], [Just "2", Just "2", Just "b", Just "2b"]) <- readChan chan
(4, ["rowid", "a", "b", "a || b"], [Just "3", Just "3", Nothing , Nothing ]) <- readChan chan
(4, ["rowid", "a", "b", "a || b"], [Just "4", Nothing , Just "d", Nothing ]) <- readChan chan

Left Ex <- try $ execWithCallback conn "SELECT 1" $ \_ _ _ -> throwIO Ex

return ()

-- Simplest SELECT
testSimplest :: TestEnv -> Test
testSimplest TestEnv{..} = TestCase $ do
Expand Down Expand Up @@ -506,6 +550,9 @@ testInterrupt TestEnv{..} = TestCase $
_ <- forkIO $ threadDelay 100000 >> interrupt conn
Left ErrorInterrupt <- Direct.step stmt
Left ErrorInterrupt <- Direct.finalize stmt

Nothing <- timeout 100000 $ interruptibly conn $ exec conn tripleSum

return ()

where
Expand Down

0 comments on commit b3ebcbf

Please sign in to comment.