@@ -64,6 +64,7 @@ module TensorFlow.Types
64
64
, AllTensorTypes
65
65
) where
66
66
67
+ import Data.Bits (shiftL , (.|.) )
67
68
import Data.ProtoLens.Message (defMessage )
68
69
import Data.Functor.Identity (Identity (.. ))
69
70
import Data.Complex (Complex )
@@ -86,6 +87,7 @@ import qualified Data.ByteString.Builder as Builder
86
87
import qualified Data.ByteString.Lazy as L
87
88
import qualified Data.Vector as V
88
89
import qualified Data.Vector.Storable as S
90
+ import Data.Vector.Split (chunksOf )
89
91
import Proto.Tensorflow.Core.Framework.AttrValue
90
92
( AttrValue
91
93
, AttrValue'ListValue
@@ -127,6 +129,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
127
129
import Proto.Tensorflow.Core.Framework.Types (DataType (.. ))
128
130
129
131
import TensorFlow.Internal.VarInt (getVarInt , putVarInt )
132
+ import qualified TensorFlow.Internal.Raw as Raw
130
133
import qualified TensorFlow.Internal.FFI as FFI
131
134
132
135
type ResourceHandle = ResourceHandleProto
@@ -317,50 +320,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
317
320
encodeTensorData = error " TODO (Complex Double)"
318
321
319
322
instance {-# OVERLAPPING #-} TensorDataType V. Vector ByteString where
320
- -- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
321
- -- table offsets for each element :: [Word64]
322
- -- at each element offset:
323
- -- string length :: VarInt64
324
- -- string data :: [Word8]
323
+ -- Strings can be encoded in various ways, see [0] for an overview.
324
+ --
325
+ -- The data starts with an array of TF_TString structs (24 bytes each), one
326
+ -- for each element in the tensor. In some cases, the actual string
327
+ -- contents are inlined in the TF_TString, in some cases they are in the
328
+ -- heap, in some cases they are appended to the end of the data.
329
+ --
330
+ -- When decoding, we delegate most of those details to the TString C API.
331
+ -- However, when encoding, the TString C API is prone to memory leaks given
332
+ -- the current design of tensorflow-haskell, so instead we manually encode
333
+ -- all the strings in the "offset" format, where none of the string data is
334
+ -- stored in separate heap objects and so no destructor hook is necessary.
335
+ --
336
+ -- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
325
337
-- Decode a TF_STRING tensor into a boxed vector of ByteStrings.
--
-- The tensor's byte buffer is expected to begin with one TF_TString struct
-- per element, each 'Raw.sizeOfTString' bytes wide. Each struct-sized slice
-- is passed to 'FFI.unsafeTStringToByteString' to obtain the string.
--
-- Calls 'error' when the buffer is too short to hold one struct per element.
decodeTensorData tensorData
    | S.length bytes < minBytes =
        error $ " Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
            show (S.length bytes) ++ " , need at least " ++ show minBytes
    | otherwise =
        -- Slice with the same struct size used for the length check above,
        -- rather than a separately hard-coded width, so the two can never
        -- disagree. Only the first numElements slices are decoded; anything
        -- after the struct table (e.g. appended string contents) is not
        -- itself a TF_TString and is dropped by 'take'.
        V.fromList
            $ map FFI.unsafeTStringToByteString
            $ take numElements (chunksOf Raw.sizeOfTString bytes)
  where
    -- Raw bytes backing the tensor.
    bytes = FFI.tensorDataBytes $ unTensorData tensorData
    -- Total element count: the product of all tensor dimensions.
    numElements = fromIntegral $ product $ FFI.tensorDataDimensions $ unTensorData tensorData
    -- Smallest buffer that can hold one TF_TString struct per element.
    minBytes = Raw.sizeOfTString * numElements
344
346
encodeTensorData (Shape xs) vec =
345
347
TensorData $ FFI. TensorData xs dt byteVector
346
348
where
347
349
dt = tensorType (undefined :: ByteString )
350
+ tableSize = fromIntegral $ Raw. sizeOfTString * (V. length vec)
348
351
-- Add a string to an offset table and data blob.
349
- addString :: (Builder , Builder , Word64 )
352
+ addString :: (Builder , Builder , Word32 , Word32 )
350
353
-> ByteString
351
- -> (Builder , Builder , Word64 )
352
- addString (table, strings, offset) str =
353
- ( table <> Builder. word64LE offset
354
- , strings <> lengthBytes <> Builder. byteString str
355
- , offset + lengthBytesLen + strLen
354
+ -> (Builder , Builder , Word32 , Word32 )
355
+ addString (table, strings, tableOffset, stringsOffset) str =
356
+ ( table <> Builder. word32LE sizeField
357
+ <> Builder. word32LE offsetField
358
+ <> Builder. word32LE capacityField
359
+ <> Builder. word32LE 0
360
+ <> Builder. word32LE 0
361
+ <> Builder. word32LE 0
362
+ , strings <> Builder. byteString str
363
+ , tableOffset + fromIntegral Raw. sizeOfTString
364
+ , stringsOffset + strLen
356
365
)
357
366
where
358
- strLen = fromIntegral $ B. length str
359
- lengthBytes = putVarInt $ fromIntegral $ B. length str
360
- lengthBytesLen =
361
- fromIntegral $ L. length $ Builder. toLazyByteString lengthBytes
367
+ strLen :: Word32 = fromIntegral $ B. length str
368
+ -- TF_TString.size includes a union tag in the first two bits.
369
+ sizeField :: Word32 = (shiftL strLen 2 ) .|. Raw. tstringOffsetTypeTag
370
+ -- offset is relative to the start of the TF_TString instance, so
371
+ -- we add the remaining distance to the end of the table to the
372
+ -- offset from the start of the string data.
373
+ offsetField :: Word32 = tableSize - tableOffset + stringsOffset
374
+ capacityField :: Word32 = strLen
362
375
-- Encode all strings.
363
- (table', strings', _) = V. foldl' addString (mempty , mempty , 0 ) vec
376
+ (table', strings', _, _ ) = V. foldl' addString (mempty , mempty , 0 , 0 ) vec
364
377
-- Concat offset table with data.
365
378
bytes = table' <> strings'
366
379
-- Convert to Vector Word8.
0 commit comments