Skip to content

Commit 5e2bd5a

Browse files
fkm3Bart Schuurmans
authored andcommitted
1 parent d24aae5 commit 5e2bd5a

File tree

4 files changed

+76
-32
lines changed

4 files changed

+76
-32
lines changed

tensorflow/src/TensorFlow/Internal/FFI.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ module TensorFlow.Internal.FFI
2626
, setSessionConfig
2727
, setSessionTarget
2828
, getAllOpList
29+
, unsafeTStringToByteString
2930
-- * Internal helper.
3031
, useProtoAsVoidPtrLen
3132
)
3233
where
3334

35+
import Control.Exception (assert)
3436
import Control.Concurrent.Async (Async, async, cancel, waitCatch)
3537
import Control.Concurrent.MVar (MVar, modifyMVarMasked_, newMVar, takeMVar)
3638
import Control.Monad (when)
@@ -61,6 +63,17 @@ import Proto.Tensorflow.Core.Protobuf.Config (ConfigProto)
6163

6264
import qualified TensorFlow.Internal.Raw as Raw
6365

66+
-- Interpret a vector of bytes as a TF_TString struct and copy the pointed
67+
-- to string into a ByteString.
68+
unsafeTStringToByteString :: S.Vector Word8 -> B.ByteString
69+
unsafeTStringToByteString v =
70+
assert (S.length v == Raw.sizeOfTString) $
71+
unsafePerformIO $ S.unsafeWith v $ \tstringPtr -> do
72+
let tstring = Raw.TString (castPtr tstringPtr)
73+
p <- Raw.stringGetDataPointer tstring
74+
n <- Raw.stringGetSize tstring
75+
B.packCStringLen (p, fromIntegral n)
76+
6477
data TensorFlowException = TensorFlowException Raw.Code T.Text
6578
deriving (Show, Eq, Typeable)
6679

tensorflow/src/TensorFlow/Internal/Raw.chs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,23 @@ message :: Status -> IO CString
4444
message = {# call TF_Message as ^ #}
4545

4646

47+
-- TString.
48+
{# pointer *TF_TString as TString newtype #}
49+
50+
sizeOfTString :: Int
51+
sizeOfTString = 24
52+
53+
-- TF_TString_Type::TF_TSTR_OFFSET
54+
tstringOffsetTypeTag :: Word32
55+
tstringOffsetTypeTag = 2
56+
57+
stringGetDataPointer :: TString -> IO CString
58+
stringGetDataPointer = {# call TF_StringGetDataPointer as ^ #}
59+
60+
stringGetSize :: TString -> IO CULong
61+
stringGetSize = {# call TF_StringGetSize as ^ #}
62+
63+
4764
-- Buffer.
4865
data Buffer
4966
{# pointer *TF_Buffer as BufferPtr -> Buffer #}

tensorflow/src/TensorFlow/Types.hs

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ module TensorFlow.Types
6464
, AllTensorTypes
6565
) where
6666

67+
import Data.Bits (shiftL, (.|.))
6768
import Data.ProtoLens.Message(defMessage)
6869
import Data.Functor.Identity (Identity(..))
6970
import Data.Complex (Complex)
@@ -86,6 +87,7 @@ import qualified Data.ByteString.Builder as Builder
8687
import qualified Data.ByteString.Lazy as L
8788
import qualified Data.Vector as V
8889
import qualified Data.Vector.Storable as S
90+
import Data.Vector.Split (chunksOf)
8991
import Proto.Tensorflow.Core.Framework.AttrValue
9092
( AttrValue
9193
, AttrValue'ListValue
@@ -127,6 +129,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
127129
import Proto.Tensorflow.Core.Framework.Types (DataType(..))
128130

129131
import TensorFlow.Internal.VarInt (getVarInt, putVarInt)
132+
import qualified TensorFlow.Internal.Raw as Raw
130133
import qualified TensorFlow.Internal.FFI as FFI
131134

132135
type ResourceHandle = ResourceHandleProto
@@ -317,50 +320,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
317320
encodeTensorData = error "TODO (Complex Double)"
318321

319322
instance {-# OVERLAPPING #-} TensorDataType V.Vector ByteString where
320-
-- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
321-
-- table offsets for each element :: [Word64]
322-
-- at each element offset:
323-
-- string length :: VarInt64
324-
-- string data :: [Word8]
323+
-- Strings can be encoded in various ways, see [0] for an overview.
324+
--
325+
-- The data starts with an array of TF_TString structs (24 bytes each), one
326+
-- for each element in the tensor. In some cases, the actual string
327+
-- contents are inlined in the TF_TString, in some cases they are in the
328+
-- heap, in some cases they are appended to the end of the data.
329+
--
330+
-- When decoding, we delegate most of those details to the TString C API.
331+
-- However, when encoding, the TString C API is prone to memory leaks given
332+
-- the current design of tensorflow-haskell, so, instead we manually encode
333+
-- all the strings in the "offset" format, where none of the string data is
334+
-- stored in separate heap objects and so no destructor hook is necessary.
335+
--
336+
-- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
325337
decodeTensorData tensorData =
326-
either (\err -> error $ "Malformed TF_STRING tensor; " ++ err) id $
327-
if expected /= count
328-
then Left $ "decodeTensorData for ByteString count mismatch " ++
329-
show (expected, count)
330-
else V.mapM decodeString (S.convert offsets)
338+
if S.length bytes < minBytes
339+
then error $ "Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
340+
show (S.length bytes) ++ ", need at least " ++ show minBytes
341+
else V.fromList $ map FFI.unsafeTStringToByteString (take numElements (chunksOf 24 bytes))
331342
where
332-
expected = S.length offsets
333-
count = fromIntegral $ product $ FFI.tensorDataDimensions
334-
$ unTensorData tensorData
335343
bytes = FFI.tensorDataBytes $ unTensorData tensorData
336-
offsets = S.take count $ S.unsafeCast bytes :: S.Vector Word64
337-
dataBytes = B.pack $ S.toList $ S.drop (count * 8) bytes
338-
decodeString :: Word64 -> Either String ByteString
339-
decodeString offset =
340-
let stringDataStart = B.drop (fromIntegral offset) dataBytes
341-
in Atto.eitherResult $ Atto.parse stringParser stringDataStart
342-
stringParser :: Atto.Parser ByteString
343-
stringParser = getVarInt >>= Atto.take . fromIntegral
344+
numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
345+
minBytes = Raw.sizeOfTString * numElements
344346
encodeTensorData (Shape xs) vec =
345347
TensorData $ FFI.TensorData xs dt byteVector
346348
where
347349
dt = tensorType (undefined :: ByteString)
350+
tableSize = fromIntegral $ Raw.sizeOfTString * (V.length vec)
348351
-- Add a string to an offset table and data blob.
349-
addString :: (Builder, Builder, Word64)
352+
addString :: (Builder, Builder, Word32, Word32)
350353
-> ByteString
351-
-> (Builder, Builder, Word64)
352-
addString (table, strings, offset) str =
353-
( table <> Builder.word64LE offset
354-
, strings <> lengthBytes <> Builder.byteString str
355-
, offset + lengthBytesLen + strLen
354+
-> (Builder, Builder, Word32, Word32)
355+
addString (table, strings, tableOffset, stringsOffset) str =
356+
( table <> Builder.word32LE sizeField
357+
<> Builder.word32LE offsetField
358+
<> Builder.word32LE capacityField
359+
<> Builder.word32LE 0
360+
<> Builder.word32LE 0
361+
<> Builder.word32LE 0
362+
, strings <> Builder.byteString str
363+
, tableOffset + fromIntegral Raw.sizeOfTString
364+
, stringsOffset + strLen
356365
)
357366
where
358-
strLen = fromIntegral $ B.length str
359-
lengthBytes = putVarInt $ fromIntegral $ B.length str
360-
lengthBytesLen =
361-
fromIntegral $ L.length $ Builder.toLazyByteString lengthBytes
367+
strLen :: Word32 = fromIntegral $ B.length str
368+
-- TF_TString.size includes a union tag in the first two bits.
369+
sizeField :: Word32 = (shiftL strLen 2) .|. Raw.tstringOffsetTypeTag
370+
-- offset is relative to the start of the TF_TString instance, so
371+
-- we add the remaining distance to the end of the table to the
372+
-- offset from the start of the string data.
373+
offsetField :: Word32 = tableSize - tableOffset + stringsOffset
374+
capacityField :: Word32 = strLen
362375
-- Encode all strings.
363-
(table', strings', _) = V.foldl' addString (mempty, mempty, 0) vec
376+
(table', strings', _, _) = V.foldl' addString (mempty, mempty, 0, 0) vec
364377
-- Concat offset table with data.
365378
bytes = table' <> strings'
366379
-- Convert to Vector Word8.

tensorflow/tensorflow.cabal

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ library
5656
, temporary
5757
, transformers
5858
, vector
59+
, vector-split
5960
extra-libraries: tensorflow
6061
default-language: Haskell2010
6162
include-dirs: .

0 commit comments

Comments
 (0)