@@ -43,6 +43,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
43
43
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op )
44
44
45
45
import qualified Data.ByteString.Char8 as BS
46
+ import TensorFlow.Session (SessionT )
46
47
47
48
testGradientSimple :: Test
48
49
testGradientSimple = testCase " testGradientSimple" $ do
@@ -290,6 +291,35 @@ testTanhGrad = testCase "testTanhGrad" $ do
290
291
TF. gradients y [x] >>= TF. run
291
292
V. fromList [1 ] @=? dx
292
293
294
+ testExpandDims :: Test
295
+ testExpandDims =
296
+ testCase " testExpandDims" $ do
297
+ ([dx], [s]) <-
298
+ TF. runSession $ do
299
+ (x :: TF. Tensor TF. Value Float ) <- TF. render $ TF. zeros $ TF. Shape [1 , 2 , 3 :: Int64 ]
300
+ let y = TF. expandDims x $ TF. constant (TF. Shape [1 ]) [0 :: Int32 ]
301
+ calculateGradWithShape y x
302
+ V. fromList [1 , 1 , 1 , 1 , 1 , 1 ] @=? dx
303
+ V. fromList [1 , 2 , 3 ] @=? s
304
+
305
+ testReshape :: Test
306
+ testReshape =
307
+ testCase " testReshape" $ do
308
+ ([dx], [s]) <-
309
+ TF. runSession $ do
310
+ (x :: TF. Tensor TF. Value Float ) <- TF. render $ TF. zeros $ TF. Shape [2 , 2 :: Int64 ]
311
+ let y = TF. reshape x $ TF. constant (TF. Shape [2 ]) [1 , 4 :: Int32 ]
312
+ calculateGradWithShape y x
313
+ V. fromList [1 , 1 , 1 , 1 ] @=? dx
314
+ V. fromList [2 , 2 ] @=? s
315
+
316
+ calculateGradWithShape :: TF. Tensor TF. Build Float -> TF. Tensor TF. Value Float -> SessionT IO ([V. Vector Float ], [V. Vector Int32 ])
317
+ calculateGradWithShape y x = do
318
+ gs <- TF. gradients y [x]
319
+ xs <- TF. run gs
320
+ (shapes :: [V. Vector Int32 ]) <- mapM (TF. run . TF. shape) gs
321
+ return (xs, shapes)
322
+
293
323
testFillGrad :: Test
294
324
testFillGrad = testCase " testFillGrad" $ do
295
325
[dx] <- TF. runSession $ do
@@ -436,6 +466,8 @@ main = defaultMain
436
466
, testReluGrad
437
467
, testReluGradGrad
438
468
, testTanhGrad
469
+ , testExpandDims
470
+ , testReshape
439
471
, testFillGrad
440
472
, testTileGrad
441
473
, testTile2DGrad
0 commit comments