Skip to content

Commit 199f1c7

Browse files
committed
1 parent b21dee9 commit 199f1c7

File tree

4 files changed

+76
-34
lines changed

4 files changed

+76
-34
lines changed

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI
2626
, setSessionConfig
2727
, setSessionTarget
2828
, getAllOpList
29+
, unsafeTStringToByteString
2930
-- * Internal helper.
3031
, useProtoAsVoidPtrLen
3132
)
3233
where
3334

35+
import Control.Exception (assert)
3436
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
3537
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
3638
import Control.Monad (when)
@@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
6163

6264
import qualified TensorFlow.Internal.Raw as Raw
6365

66+
-- | Interpret a vector of bytes as a @TF_TString@ struct and copy the string
-- it refers to into a fresh 'B.ByteString'.
--
-- The vector must be exactly 'Raw.sizeOfTString' bytes long. This is checked
-- unconditionally: 'Control.Exception.assert' is compiled out when GHC
-- optimisations are enabled, which would leave optimised builds free to read
-- past the end of a short vector.
--
-- The function is still \"unsafe\" because the /contents/ of the bytes are
-- not validated; they are handed to the TString C API as-is, so the caller
-- must supply bytes that really hold a @TF_TString@.
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
unsafeTStringToByteString v
    | S.length v /= Raw.sizeOfTString =
        error $ "unsafeTStringToByteString: expected "
             ++ show Raw.sizeOfTString ++ " bytes, got " ++ show (S.length v)
    | otherwise =
        unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do
            -- View the vector's storage as a TF_TString; nothing is copied yet.
            let tstring = Raw.TString (castPtr tstringPtr)
            p <- Raw.stringGetDataPointer tstring
            n <- Raw.stringGetSize tstring
            -- packCStringLen copies the bytes, so the result does not alias
            -- the vector's storage once unsafeWith returns.
            B.packCStringLen (p, fromIntegral n)
76+
6477
data TensorFlowException = TensorFlowException Raw.Code T.Text
6578
deriving (Show, Eq, Typeable)
6679

tensorflow/src/TensorFlow/Internal/Raw.chs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ message :: Status -> IO CString
4444
message = {# call TF_Message as ^ #}
4545

4646

47+
-- TString.
--
-- Opaque handle wrapping a pointer to a C @TF_TString@ struct.
{# pointer *TF_TString as TString newtype #}

-- | Size in bytes of a single @TF_TString@ struct.
--
-- NOTE(review): this is a hard-coded literal rather than a value derived
-- from the C header, so it has to be kept in sync with the TensorFlow
-- definition by hand.
sizeOfTString :: Int
sizeOfTString = 24

-- | Type tag for the \"offset\" string representation
-- (@TF_TString_Type::TF_TSTR_OFFSET@).
tstringOffsetTypeTag :: Word32
tstringOffsetTypeTag = 0x2

-- | Binding to @TF_StringGetDataPointer@: pointer to the string's bytes.
stringGetDataPointer :: TString -> IO CString
stringGetDataPointer tstr = {# call TF_StringGetDataPointer as ^ #} tstr

-- | Binding to @TF_StringGetSize@: the size reported by the C API for the
-- string.
stringGetSize :: TString -> IO CULong
stringGetSize tstr = {# call TF_StringGetSize as ^ #} tstr
62+
63+
4764
-- Buffer.
4865
data Buffer
4966
{# pointer *TF_Buffer as BufferPtr -> Buffer #}

tensorflow/src/TensorFlow/Types.hs

Lines changed: 45 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ module TensorFlow.Types
6464
, AllTensorTypes
6565
) where
6666

67+
import Data.Bits (shiftL, (.|.))
6768
import Data.ProtoLens.Message(defMessage)
6869
import Data.Functor.Identity (Identity(..))
6970
import Data.Complex (Complex)
@@ -78,14 +79,14 @@ import GHC.Exts (Constraint, IsList(..))
7879
import Lens.Family2 (Lens', view, (&), (.~), (^..), under)
7980
import Lens.Family2.Unchecked (adapter)
8081
import Text.Printf (printf)
81-
import qualified Data.Attoparsec.ByteString as Atto
8282
import Data.ByteString (ByteString)
8383
import qualified Data.ByteString as B
8484
import Data.ByteString.Builder (Builder)
8585
import qualified Data.ByteString.Builder as Builder
8686
import qualified Data.ByteString.Lazy as L
8787
import qualified Data.Vector as V
8888
import qualified Data.Vector.Storable as S
89+
import Data.Vector.Split (chunksOf)
8990
import Proto.Tensorflow.Core.Framework.AttrValue
9091
( AttrValue
9192
, AttrValue'ListValue
@@ -126,7 +127,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
126127
)
127128
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
128129

129-
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
130+
import qualified TensorFlow.Internal.Raw as Raw
130131
import qualified TensorFlow.Internal.FFI as FFI
131132

132133
type ResourceHandle = ResourceHandleProto
@@ -317,50 +318,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
317318
encodeTensorData = error "TODO (Complex Double)"
318319

319320
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
320-
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
321-
-- table offsets for each element :: [Word64]
322-
-- at each element offset:
323-
-- string length :: VarInt64
324-
-- string data :: [Word8]
321+
-- Strings can be encoded in various ways, see [0] for an overview.
322+
--
323+
-- The data starts with an array of TF_TString structs (24 bytes each), one
324+
-- for each element in the tensor. In some cases, the actual string
325+
-- contents are inlined in the TF_TString, in some cases they are in the
326+
-- heap, in some cases they are appended to the end of the data.
327+
--
328+
-- When decoding, we delegate most of those details to the TString C API.
329+
-- However, when encoding, the TString C API is prone to memory leaks given
330+
-- the current design of tensorflow-haskell, so, instead we manually encode
331+
-- all the strings in the "offset" format, where none of the string data is
332+
-- stored in separate heap objects and so no destructor hook is necessary.
333+
--
334+
-- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
325335
decodeTensorData tensorData =
326-
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
327-
if expected /= count
328-
then Left $ "decodeTensorData for ByteString count mismatch " ++
329-
show (expected, count)
330-
else V.mapM decodeString (S.convert offsets)
336+
if S.length bytes < minBytes
337+
then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
338+
show (S.length bytes) ++ ", need at least " ++ show minBytes
339+
else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf 24 bytes))
331340
where
332-
expected = S.length offsets
333-
count = fromIntegral $ product $ FFI.tensorDataDimensions
334-
$ unTensorData tensorData
335341
bytes = FFI.tensorDataBytes $ unTensorData tensorData
336-
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
337-
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
338-
decodeString :: Word64 -> Either String ByteString
339-
decodeString offset =
340-
let stringDataStart = B.drop (fromIntegral offset) dataBytes
341-
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
342-
stringParser :: Atto.Parser ByteString
343-
stringParser = getVarInt >>= Atto.take . fromIntegral
342+
numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
343+
minBytes = Raw.sizeOfTString * numElements
344344
encodeTensorData (Shape xs) vec =
345345
TensorData $ FFI.TensorData xs dt byteVector
346346
where
347347
dt = tensorType (undefined :: ByteString)
348+
tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
348349
-- Add a string to an offset table and data blob.
349-
addString :: (Builder, Builder, Word64)
350+
addString :: (Builder, Builder, Word32, Word32)
350351
-> ByteString
351-
-> (Builder, Builder, Word64)
352-
addString (table, strings, offset) str =
353-
( table <> Builder.word64LE offset
354-
, strings <> lengthBytes <> Builder.byteString str
355-
, offset + lengthBytesLen + strLen
352+
-> (Builder, Builder, Word32, Word32)
353+
addString (table, strings, tableOffset, stringsOffset) str =
354+
( table <> Builder.word32LE sizeField
355+
<> Builder.word32LE offsetField
356+
<> Builder.word32LE capacityField
357+
<> Builder.word32LE 0
358+
<> Builder.word32LE 0
359+
<> Builder.word32LE 0
360+
, strings <> Builder.byteString str
361+
, tableOffset + fromIntegral Raw.sizeOfTString
362+
, stringsOffset + strLen
356363
)
357364
where
358-
strLen = fromIntegral $ B.length str
359-
lengthBytes = putVarInt $ fromIntegral $ B.length str
360-
lengthBytesLen =
361-
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
365+
strLen :: Word32 = fromIntegral $ B.length str
366+
-- TF_TString.size includes a union tag in the first two bits.
367+
sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
368+
-- offset is relative to the start of the TF_TString instance, so
369+
-- we add the remaining distance to the end of the table to the
370+
-- offset from the start of the string data.
371+
offsetField :: Word32 = tableSize - tableOffset + stringsOffset
372+
capacityField :: Word32 = strLen
362373
-- Encode all strings.
363-
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
374+
(table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
364375
-- Concat offset table with data.
365376
bytes = table' <> strings'
366377
-- Convert to Vector Word8.

tensorflow/tensorflow.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ library
5656
, temporary
5757
, transformers
5858
, vector
59+
, vector-split
5960
extra-libraries: tensorflow
6061
default-language: Haskell2010
6162
include-dirs: .

0 commit comments

Comments
 (0)