@@ -32,9 +32,9 @@ import Control.Monad(forM_, replicateM, zipWithM)
32
32
import Control.Monad.IO.Class (liftIO )
33
33
34
34
import qualified TensorFlow.Core as TF
35
- import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput' , max , maximum , tile , pad , batchToSpaceND , spaceToBatchND , squeeze , sqrt )
35
+ import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput' , max , maximum , tile , pad , batchToSpaceND , spaceToBatchND , squeeze , sqrt , slice , shape )
36
36
import qualified TensorFlow.Gradient as TF
37
- import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable )
37
+ import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable , shape )
38
38
import qualified TensorFlow.Output as TF
39
39
import qualified TensorFlow.Types as TF
40
40
import qualified TensorFlow.Variable as TF
@@ -324,6 +324,7 @@ testPad =
324
324
V. fromList [1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 , 1 ] @=? dx
325
325
V. fromList [2 , 2 , 3 ] @=? s
326
326
327
+
327
328
testSqrt :: Test
328
329
testSqrt = testCase " testSqrt" $ do
329
330
[dx] <- TF. runSession $ do
@@ -332,6 +333,25 @@ testSqrt = testCase "testSqrt" $ do
332
333
TF. gradients y [x] >>= TF. run
333
334
V. fromList [2 ] @=? dx
334
335
336
+ testSlice :: Test
337
+ testSlice =
338
+ testCase " testSlice" $ do
339
+ ([dx], [s]) <-
340
+ TF. runSession $ do
341
+ (x :: TF. Tensor TF. Value Float ) <- TF. render $ TF. zeros $ TF. Shape [2 , 3 , 4 :: Int64 ]
342
+ (z :: TF. Tensor TF. Value Float ) <- TF. render $ TF. zeros $ TF. Shape [1 , 2 , 2 :: Int64 ]
343
+ let y = TF. slice x (TF. constant (TF. Shape [3 ]) [1 , 1 , 1 :: Int32 ]) (TF. shape z)
344
+ calculateGradWithShape y x
345
+ let expected =
346
+ [0 , 0 , 0 , 0 ,
347
+ 0 , 0 , 0 , 0 ,
348
+ 0 , 0 , 0 , 0 ,
349
+ 0 , 0 , 0 , 0 ,
350
+ 0 , 1 , 1 , 0 ,
351
+ 0 , 1 , 1 , 0 ]
352
+ V. fromList expected @=? dx
353
+ V. fromList [2 , 3 , 4 ] @=? s
354
+
335
355
testBatchToSpaceND :: Test
336
356
testBatchToSpaceND =
337
357
testCase " testBatchToSpaceND" $ do
@@ -526,6 +546,7 @@ main = defaultMain
526
546
, testReshape
527
547
, testPad
528
548
, testSqrt
549
+ , testSlice
529
550
, testBatchToSpaceND
530
551
, testSpaceToBatchND
531
552
, testSqueeze
0 commit comments