@@ -64,6 +64,7 @@ module TensorFlow.Types
64
64
, AllTensorTypes
65
65
) where
66
66
67
+ import Data.Bits (shiftL , (.|.) )
67
68
import Data.ProtoLens.Message (defMessage )
68
69
import Data.Functor.Identity (Identity (.. ))
69
70
import Data.Complex (Complex )
@@ -78,14 +79,14 @@ import GHC.Exts (Constraint, IsList(..))
78
79
import Lens.Family2 (Lens' , view , (&) , (.~) , (^..) , under )
79
80
import Lens.Family2.Unchecked (adapter )
80
81
import Text.Printf (printf )
81
- import qualified Data.Attoparsec.ByteString as Atto
82
82
import Data.ByteString (ByteString )
83
83
import qualified Data.ByteString as B
84
84
import Data.ByteString.Builder (Builder )
85
85
import qualified Data.ByteString.Builder as Builder
86
86
import qualified Data.ByteString.Lazy as L
87
87
import qualified Data.Vector as V
88
88
import qualified Data.Vector.Storable as S
89
+ import Data.Vector.Split (chunksOf )
89
90
import Proto.Tensorflow.Core.Framework.AttrValue
90
91
( AttrValue
91
92
, AttrValue'ListValue
@@ -126,7 +127,7 @@ import Proto.Tensorflow.Core.Framework.TensorShape_Fields
126
127
)
127
128
import Proto.Tensorflow.Core.Framework.Types (DataType (.. ))
128
129
129
- import TensorFlow.Internal.VarInt ( getVarInt , putVarInt )
130
+ import qualified TensorFlow.Internal.Raw as Raw
130
131
import qualified TensorFlow.Internal.FFI as FFI
131
132
132
133
type ResourceHandle = ResourceHandleProto
@@ -317,50 +318,60 @@ instance {-# OVERLAPPING #-} TensorDataType V.Vector (Complex Double) where
317
318
encodeTensorData = error " TODO (Complex Double)"
318
319
319
320
instance {-# OVERLAPPING #-} TensorDataType V. Vector ByteString where
320
- -- Encoded data layout (described in third_party/tensorflow/c/c_api.h):
321
- -- table offsets for each element :: [Word64]
322
- -- at each element offset:
323
- -- string length :: VarInt64
324
- -- string data :: [Word8]
321
+ -- Strings can be encoded in various ways, see [0] for an overview.
322
+ --
323
+ -- The data starts with an array of TF_TString structs (24 bytes each), one
324
+ -- for each element in the tensor. In some cases, the actual string
325
+ -- contents are inlined in the TF_TString, in some cases they are in the
326
+ -- heap, in some cases they are appended to the end of the data.
327
+ --
328
+ -- When decoding, we delegate most of those details to the TString C API.
329
+ -- However, when encoding, the TString C API is prone to memory leaks given
330
+ -- the current design of tensorflow-haskell, so, instead we manually encode
331
+ -- all the strings in the "offset" format, where none of the string data is
332
+ -- stored in separate heap objects and so no destructor hook is necessary.
333
+ --
334
+ -- [0] https://github.com/tensorflow/community/blob/master/rfcs/20190411-string-unification.md
325
335
decodeTensorData tensorData =
326
- either (\ err -> error $ " Malformed TF_STRING tensor; " ++ err) id $
327
- if expected /= count
328
- then Left $ " decodeTensorData for ByteString count mismatch " ++
329
- show (expected, count)
330
- else V. mapM decodeString (S. convert offsets)
336
+ if S. length bytes < minBytes
337
+ then error $ " Malformed TF_STRING tensor; decodeTensorData for ByteString with too few bytes, got " ++
338
+ show (S. length bytes) ++ " , need at least " ++ show minBytes
339
+ else V. fromList $ map FFI. unsafeTStringToByteString (take numElements (chunksOf 24 bytes))
331
340
where
332
- expected = S. length offsets
333
- count = fromIntegral $ product $ FFI. tensorDataDimensions
334
- $ unTensorData tensorData
335
341
bytes = FFI. tensorDataBytes $ unTensorData tensorData
336
- offsets = S. take count $ S. unsafeCast bytes :: S. Vector Word64
337
- dataBytes = B. pack $ S. toList $ S. drop (count * 8 ) bytes
338
- decodeString :: Word64 -> Either String ByteString
339
- decodeString offset =
340
- let stringDataStart = B. drop (fromIntegral offset) dataBytes
341
- in Atto. eitherResult $ Atto. parse stringParser stringDataStart
342
- stringParser :: Atto. Parser ByteString
343
- stringParser = getVarInt >>= Atto. take . fromIntegral
342
+ numElements = fromIntegral $ product $ FFI. tensorDataDimensions $ unTensorData tensorData
343
+ minBytes = Raw. sizeOfTString * numElements
344
344
encodeTensorData (Shape xs) vec =
345
345
TensorData $ FFI. TensorData xs dt byteVector
346
346
where
347
347
dt = tensorType (undefined :: ByteString )
348
+ tableSize = fromIntegral $ Raw. sizeOfTString * (V. length vec)
348
349
-- Add a string to an offset table and data blob.
349
- addString :: (Builder , Builder , Word64 )
350
+ addString :: (Builder , Builder , Word32 , Word32 )
350
351
-> ByteString
351
- -> (Builder , Builder , Word64 )
352
- addString (table, strings, offset) str =
353
- ( table <> Builder. word64LE offset
354
- , strings <> lengthBytes <> Builder. byteString str
355
- , offset + lengthBytesLen + strLen
352
+ -> (Builder , Builder , Word32 , Word32 )
353
+ addString (table, strings, tableOffset, stringsOffset) str =
354
+ ( table <> Builder. word32LE sizeField
355
+ <> Builder. word32LE offsetField
356
+ <> Builder. word32LE capacityField
357
+ <> Builder. word32LE 0
358
+ <> Builder. word32LE 0
359
+ <> Builder. word32LE 0
360
+ , strings <> Builder. byteString str
361
+ , tableOffset + fromIntegral Raw. sizeOfTString
362
+ , stringsOffset + strLen
356
363
)
357
364
where
358
- strLen = fromIntegral $ B. length str
359
- lengthBytes = putVarInt $ fromIntegral $ B. length str
360
- lengthBytesLen =
361
- fromIntegral $ L. length $ Builder. toLazyByteString lengthBytes
365
+ strLen :: Word32 = fromIntegral $ B. length str
366
+ -- TF_TString.size includes a union tag in the first two bits.
367
+ sizeField :: Word32 = (shiftL strLen 2 ) .|. Raw. tstringOffsetTypeTag
368
+ -- offset is relative to the start of the TF_TString instance, so
369
+ -- we add the remaining distance to the end of the table to the
370
+ -- offset from the start of the string data.
371
+ offsetField :: Word32 = tableSize - tableOffset + stringsOffset
372
+ capacityField :: Word32 = strLen
362
373
-- Encode all strings.
363
- (table', strings', _) = V. foldl' addString (mempty , mempty , 0 ) vec
374
+ (table', strings', _, _ ) = V. foldl' addString (mempty , mempty , 0 , 0 ) vec
364
375
-- Concat offset table with data.
365
376
bytes = table' <> strings'
366
377
-- Convert to Vector Word8.
0 commit comments