Skip to content

Commit 29bad26

Browse files
committed
Implement owned ByteString return with finalizers
1 parent c776f5b commit 29bad26

File tree

6 files changed

+135
-67
lines changed

6 files changed

+135
-67
lines changed

inline-rust.cabal

+1
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ library
2929
Language.Rust.Inline.TH
3030
other-modules: Language.Rust.Inline.Context
3131
Language.Rust.Inline.Context.Prelude
32+
Language.Rust.Inline.Context.ByteString
3233
Language.Rust.Inline.Marshal
3334
Language.Rust.Inline.Parser
3435
Language.Rust.Inline.Pretty

src/Language/Rust/Inline.hs

+38-15
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ module Language.Rust.Inline (
8484
-- externCrate,
8585

8686
import Language.Rust.Inline.Context
87+
import Language.Rust.Inline.Context.ByteString (bytestrings)
8788
import Language.Rust.Inline.Context.Prelude (prelude)
8889
import Language.Rust.Inline.Internal
8990
import Language.Rust.Inline.Marshal
@@ -101,13 +102,17 @@ import Foreign.Marshal.Alloc (alloca, free)
101102
import Foreign.Marshal.Array (newArray, withArrayLen)
102103
import Foreign.Marshal.Unsafe (unsafeLocalState)
103104
import Foreign.Marshal.Utils (new, with)
104-
import Foreign.Ptr (Ptr, freeHaskellFunPtr)
105+
import Foreign.Ptr (FunPtr, Ptr, freeHaskellFunPtr)
105106

106107
import Control.Monad (void)
107108
import Data.List (intercalate)
108109
import Data.Traversable (for)
110+
import Data.Word (Word8)
109111
import System.Random (randomIO)
110112

113+
import qualified Data.ByteString.Unsafe as ByteString
114+
import Foreign.Storable (Storable (..))
115+
111116
{- $overview
112117
113118
This module provides the facility for dropping in bits of Rust code into your
@@ -307,17 +312,17 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
307312

308313
-- Convert the Haskell return type to a marshallable FFI type
309314
(returnFfi, haskRet') <- do
310-
marshalFrom <- ghcMarshallable haskRet
311-
ret <- case marshalFrom of
315+
marshalForm <- ghcMarshallable haskRet
316+
ret <- case marshalForm of
312317
BoxedDirect -> [t|IO $(pure haskRet)|]
313318
BoxedIndirect -> [t|Ptr $(pure haskRet) -> IO ()|]
314319
UnboxedDirect
315320
| isPure -> pure haskRet
316321
| otherwise ->
317322
let retTy = showTy haskRet
318323
in fail ("Cannot put unlifted type ‘" ++ retTy ++ "’ in IO")
319-
ByteString -> undefined
320-
pure (marshalFrom, pure ret)
324+
ByteString -> [t|Ptr (Ptr Word8, Word, FunPtr (Ptr Word8 -> Word -> IO ())) -> IO ()|]
325+
pure (marshalForm, pure ret)
321326

322327
-- Convert the Haskell arguments to marshallable FFI types
323328
(marshalForms, haskArgs') <- fmap unzip $
@@ -341,14 +346,17 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
341346
ptr <- [t|Ptr $(pure haskArg)|]
342347
pure (BoxedIndirect, ptr)
343348
ByteString -> do
344-
rbsT <- [t|Ptr RustByteString|]
349+
rbsT <- [t|Ptr (Ptr Word8, Word)|]
345350
pure (ByteString, rbsT)
346351
_ -> pure (marshalForm, haskArg)
347352

348353
-- Generate the Haskell FFI import declaration and emit it
354+
bsFree <- newName $ "bsFree" ++ show (abs q)
355+
bsFreeSig <- [t|FunPtr (Ptr Word8 -> Word -> IO ()) -> Ptr Word8 -> Word -> IO ()|]
349356
haskSig <- foldr (\l r -> [t|$(pure l) -> $r|]) haskRet' haskArgs'
350357
let ffiImport = ForeignD (ImportF CCall safety qqStrName qqName haskSig)
351-
addTopDecls [ffiImport]
358+
let ffiBsFree = ForeignD (ImportF CCall Safe "dynamic" bsFree bsFreeSig)
359+
addTopDecls [ffiImport, ffiBsFree]
352360

353361
-- Generate the Haskell FFI call
354362
let goArgs ::
@@ -363,7 +371,24 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
363371
-- accumulated arguments. If the return value is not marshallable, we have to
364372
-- 'alloca' some space to put the return value.
365373
goArgs acc []
366-
| returnFfi /= BoxedIndirect = appsE (varE qqName : reverse acc)
374+
| returnFfi == ByteString = do
375+
ret <- newName "ret"
376+
ptr <- newName "ptr"
377+
len <- newName "len"
378+
finalizer <- newName "finalizer"
379+
[e|
380+
alloca
381+
( \($(varP ret)) ->
382+
do
383+
$(appsE (varE qqName : reverse (varE ret : acc)))
384+
($(varP ptr), $(varP len), $(varP finalizer)) <- peek $(varE ret)
385+
ByteString.unsafePackCStringFinalizer
386+
$(varE ptr)
387+
(fromIntegral $(varE len))
388+
($(varE bsFree) $(varE finalizer) $(varE ptr) $(varE len))
389+
)
390+
|]
391+
| byValue returnFfi = appsE (varE qqName : reverse acc)
367392
| otherwise = do
368393
ret <- newName "ret"
369394
[e|
@@ -385,17 +410,15 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
385410
| marshalForm == ByteString -> do
386411
ptr <- newName "ptr"
387412
len <- newName "len"
388-
bs <- newName "bs"
389413
bsp <- newName "bsp"
390414
[e|
391415
withByteString
392416
$(varE argName)
393417
( \($(varP ptr)) ($(varP len)) ->
394-
let $(varP bs) = RustByteString $(varE ptr) $(varE len)
395-
in with $(varE bs) (\($(varP bsp)) -> $(goArgs (varE bsp : acc) args))
418+
with ($(varE ptr), $(varE len)) (\($(varP bsp)) -> $(goArgs (varE bsp : acc) args))
396419
)
397420
|]
398-
| passByValue marshalForm -> goArgs (varE argName : acc) args
421+
| byValue marshalForm -> goArgs (varE argName : acc) args
399422
| otherwise -> do
400423
x <- newName "x"
401424
[e|
@@ -421,7 +444,7 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
421444
mergeArgs t (Just tInter) = (fmap (const mempty) tInter, t)
422445

423446
-- Generate the Rust function.
424-
let retByVal = returnFfi /= BoxedIndirect
447+
let retByVal = byValue returnFfi
425448
(retArg, retTy, ret)
426449
| retByVal =
427450
( []
@@ -441,15 +464,15 @@ processQQ safety isPure (QQParse rustRet rustBody rustNamedArgs) = do
441464
", "
442465
( [ s ++ ": " ++ marshal (renderType t)
443466
| (s, t, v) <- zip3 rustArgNames rustArgs' marshalForms
444-
, let marshal x = if passByValue v then x else "*const " ++ x
467+
, let marshal x = if byValue v then x else "*const " ++ x
445468
]
446469
++ retArg
447470
)
448471
, ") -> " ++ retTy ++ " {"
449472
, unlines
450473
[ " let " ++ s ++ ": " ++ renderType t ++ " = " ++ marshal s ++ ".marshal();"
451474
| (s, t, v) <- zip3 rustArgNames rustConvertedArgs marshalForms
452-
, let marshal x = if passByValue v then x else "unsafe { ::std::ptr::read(" ++ x ++ ") }"
475+
, let marshal x = if byValue v then x else "unsafe { ::std::ptr::read(" ++ x ++ ") }"
453476
]
454477
, " let out: " ++ renderType rustConvertedRet ++ " = (|| {" ++ renderTokens rustBody ++ "})();"
455478
, " " ++ ret

src/Language/Rust/Inline/Context.hs

-45
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,9 @@
55
{-# LANGUAGE QuasiQuotes #-}
66
{-# LANGUAGE TemplateHaskell #-}
77

8-
{- |
9-
Module : Language.Rust.Inline.Context
10-
Description : Defines contexts (rules mapping Rust types to Haskell types)
11-
Copyright : (c) Alec Theriault, 2017
12-
License : BSD-style
13-
Maintainer : ners <ners@gmx.ch>
14-
Stability : experimental
15-
Portability : GHC
16-
-}
178
module Language.Rust.Inline.Context where
189

1910
import Language.Rust.Inline.Pretty (renderType)
20-
import Language.Rust.Inline.TH.Storable (mkStorable)
2111

2212
import Language.Rust.Quote (ty)
2313
import Language.Rust.Syntax (
@@ -41,10 +31,7 @@ import Data.Int (Int16, Int32, Int64, Int8)
4131
import Data.Word (Word16, Word32, Word64, Word8)
4232
import Foreign.C.Types -- pretty much every type here is used
4333
import Foreign.Ptr (FunPtr, Ptr)
44-
import Foreign.Storable
4534

46-
import Data.ByteString (ByteString)
47-
import qualified Data.ByteString as ByteString
4835
import GHC.Exts (
4936
ByteArray#,
5037
Char#,
@@ -403,35 +390,3 @@ functions = do
403390
, " fn marshal(self) -> (" ++ f ++ ") { self }"
404391
, "}"
405392
]
406-
407-
data RustByteString = RustByteString (Ptr Word8) Word
408-
mkStorable [t|Storable RustByteString|]
409-
410-
bytestrings :: Q Context
411-
bytestrings = do
412-
bytestringT <- [t|ByteString|]
413-
pure $ Context ([rule], [rev bytestringT], [rustByteString, impl])
414-
where
415-
rule rty _
416-
| rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |])
417-
rule _ _ = mempty
418-
419-
rev bytestringT hty _
420-
| hty == bytestringT = pure . pure . void $ [ty| &[u8] |]
421-
rev _ _ _ = mempty
422-
423-
rustByteString =
424-
unlines
425-
[ "#[repr(C)]"
426-
, "pub struct RustByteString(*const u8, usize);"
427-
]
428-
429-
impl =
430-
unlines
431-
[ "impl<'a> MarshalInto<&'a [u8]> for RustByteString {"
432-
, " fn marshal(self) -> &'a [u8] {"
433-
, " let RustByteString(ptr, len) = self;"
434-
, " unsafe { std::slice::from_raw_parts(ptr, len) }"
435-
, " }"
436-
, "}"
437-
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
{-# LANGUAGE OverloadedStrings #-}
2+
{-# LANGUAGE QuasiQuotes #-}
3+
{-# LANGUAGE ScopedTypeVariables #-}
4+
{-# LANGUAGE TemplateHaskellQuotes #-}
5+
{-# OPTIONS_GHC -w #-}
6+
7+
module Language.Rust.Inline.Context.ByteString where
8+
9+
import Language.Rust.Inline.Context
10+
import Language.Rust.Inline.TH
11+
12+
import Language.Rust.Data.Ident (Ident (..), mkIdent)
13+
14+
import Language.Haskell.TH
15+
import Language.Haskell.TH.Syntax
16+
17+
import Language.Rust.Quote (ty)
18+
import Language.Rust.Syntax
19+
20+
import Foreign.Storable
21+
22+
import Control.Monad (join, unless)
23+
import Data.List (intercalate)
24+
import Data.Maybe (fromMaybe)
25+
26+
import Data.ByteString (ByteString)
27+
import qualified Data.ByteString as ByteString
28+
import Data.Functor (void)
29+
import Data.Word (Word8)
30+
import Foreign.Ptr (Ptr)
31+
32+
bytestrings :: Q Context
33+
bytestrings = do
34+
bytestringT <- [t|ByteString|]
35+
pure $ Context ([rule], [rev bytestringT], [rustByteString, impl])
36+
where
37+
rule rty _
38+
| rty == void [ty| &[u8] |] = pure ([t|ByteString|], pure . pure $ void [ty| RustByteString |])
39+
| rty == void [ty| Vec<u8> |] = pure ([t|ByteString|], pure . pure $ void [ty| RustMutByteString |])
40+
rule _ _ = mempty
41+
42+
rev _ _ _ = mempty
43+
44+
rustByteString =
45+
unlines
46+
[ "#[repr(C)]"
47+
, "#[derive(Copy,Clone)]"
48+
, "pub struct RustByteString(*const u8, usize);"
49+
, ""
50+
, "#[repr(C)]"
51+
, "#[derive(Copy, Clone)]"
52+
, "pub struct RustMutByteString(*mut u8, usize, extern \"C\" fn (*mut u8, usize) -> ());"
53+
]
54+
55+
impl =
56+
unlines
57+
[ "impl<'a> MarshalInto<&'a [u8]> for RustByteString {"
58+
, " fn marshal(self) -> &'a [u8] {"
59+
, " let RustByteString(ptr, len) = self;"
60+
, " unsafe { std::slice::from_raw_parts(ptr, len) }"
61+
, " }"
62+
, "}"
63+
, ""
64+
, "impl MarshalInto<RustMutByteString> for Vec<u8> {"
65+
, " fn marshal(self) -> RustMutByteString {"
66+
, " let bytes = Box::leak(self.into_boxed_slice());"
67+
, " let len = bytes.len();"
68+
, ""
69+
, " extern fn free_bytestring(ptr: *mut u8, len: usize) -> () {"
70+
, " let bytes = unsafe { Box::from_raw(std::ptr::slice_from_raw_parts_mut(ptr, len) ) };"
71+
, " drop(bytes)"
72+
, " }"
73+
, " RustMutByteString(bytes.as_mut_ptr(), len, free_bytestring)"
74+
, " }"
75+
, "}"
76+
]

src/Language/Rust/Inline/Marshal.hs

+2-2
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,8 @@ data MarshalForm
3939
| ByteString
4040
deriving (Eq)
4141

42-
passByValue :: MarshalForm -> Bool
43-
passByValue = (`elem` [UnboxedDirect, BoxedDirect])
42+
byValue :: MarshalForm -> Bool
43+
byValue = (`elem` [UnboxedDirect, BoxedDirect])
4444

4545
-- | Identify which types can be marshalled by the GHC FFI and which types are
4646
-- unlifted. A negative response to the first of these questions doesn't mean

tests/ByteStrings.hs

+18-5
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
1-
{-# LANGUAGE QuasiQuotes, TemplateHaskell #-}
1+
{-# LANGUAGE QuasiQuotes #-}
2+
{-# LANGUAGE TemplateHaskell #-}
3+
24
module ByteStrings where
35

46
import Language.Rust.Inline
57
import Test.Hspec
68

79
import Data.ByteString (ByteString)
810
import qualified Data.ByteString as ByteString
11+
import qualified Data.ByteString.Unsafe as ByteString
912
import Data.String
1013

1114
extendContext basic
@@ -14,10 +17,20 @@ setCrateModule
1417

1518
bytestringSpec :: Spec
1619
bytestringSpec = describe "ByteStrings" $ do
17-
it "can marshal ByteString arguments" $ do
18-
let inputs = ByteString.pack [0, 1, 2, 3]
19-
rustSum = [rust| u8 {
20+
it "can marshal ByteString arguments" $ do
21+
let inputs = ByteString.pack [0, 1, 2, 3]
22+
rustSum =
23+
[rust| u8 {
2024
let inputs = $( inputs: &[u8] );
2125
inputs.iter().sum()
2226
} |]
23-
rustSum `shouldBe` sum (ByteString.unpack inputs)
27+
rustSum `shouldBe` sum (ByteString.unpack inputs)
28+
29+
it "can marshal ByteString return values" $ do
30+
let rustBs =
31+
[rust| Vec<u8> {
32+
vec![0, 1, 2, 3]
33+
} |]
34+
ByteString.pack [0, 1, 2, 3]
35+
`shouldBe` rustBs
36+
ByteString.unsafeFinalize rustBs

0 commit comments

Comments
 (0)