Skip to content

Commit 3cfd96e

Browse files
erikaborfkm3
authored andcommitted
Add gradient for slice function (#234)
1 parent 666dce9 commit 3cfd96e

File tree

2 files changed

+45
-2
lines changed

2 files changed

+45
-2
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,27 @@ opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
710710
gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
711711
gradientSliceSize = shape (x :: Tensor Build Float)
712712

713+
-- Gradient for Slice
714+
-- Create an Nx2 padding where N is the rank of (grad of) Slice and the first
715+
-- column represents how many zeros are to be prepended for each dimension, and the second
716+
-- column indicates how many zeros are appended.
717+
-- The number of zeros to prepend is the shape of the beginvec.
718+
-- The number of zeros to append is the shape of the inputvec
719+
-- elementwise-subtracted by both the beginvec and sizevec.
720+
-- Some more reshaping is needed to assemble this tensor with the
721+
-- right dimensions.
722+
opGrad "Slice" _ [toT -> inputvec, toT -> beginvec, _] [dz] =
723+
[Just $ CoreOps.pad dz paddings, Nothing, Nothing]
724+
where
725+
v1 = vector [1 :: Int32]
726+
inputRank' = CoreOps.rank (inputvec :: Tensor Build Float)
727+
-- For some reason inputRank' has an empty shape
728+
inputRank = CoreOps.reshape inputRank' v1
729+
padShape = CoreOps.concat 0 [inputRank, v1]
730+
beforePad = CoreOps.reshape beginvec padShape
731+
afterPad = CoreOps.reshape (shape inputvec - shape dz - beginvec) padShape
732+
paddings = CoreOps.concat 1 [beforePad, afterPad]
733+
713734
-- TODO: This could be either Int32 or Int64.
714735
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
715736
[Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
@@ -862,6 +883,7 @@ numOutputs o =
862883
"Reshape" -> 1
863884
"Select" -> 1
864885
"Size" -> 1
886+
"Slice" -> 1
865887
"SoftmaxCrossEntropyWithLogits" -> 2
866888
"SpaceToBatchND" -> 1
867889
"SparseSegmentSum" -> 1

tensorflow-ops/tests/GradientTest.hs

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,9 @@ import Control.Monad(forM_, replicateM, zipWithM)
3232
import Control.Monad.IO.Class (liftIO)
3333

3434
import qualified TensorFlow.Core as TF
35-
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
35+
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
3636
import qualified TensorFlow.Gradient as TF
37-
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
37+
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
3838
import qualified TensorFlow.Output as TF
3939
import qualified TensorFlow.Types as TF
4040
import qualified TensorFlow.Variable as TF
@@ -324,6 +324,7 @@ testPad =
324324
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
325325
V.fromList [2, 2, 3] @=? s
326326

327+
327328
testSqrt :: Test
328329
testSqrt = testCase "testSqrt" $ do
329330
[dx] <- TF.runSession $ do
@@ -332,6 +333,25 @@ testSqrt = testCase "testSqrt" $ do
332333
TF.gradients y [x] >>= TF.run
333334
V.fromList [2] @=? dx
334335

336+
testSlice :: Test
337+
testSlice =
338+
testCase "testSlice" $ do
339+
([dx], [s]) <-
340+
TF.runSession $ do
341+
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 3, 4 :: Int64]
342+
(z :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 2 :: Int64]
343+
let y = TF.slice x (TF.constant (TF.Shape [3]) [1, 1, 1 :: Int32]) (TF.shape z)
344+
calculateGradWithShape y x
345+
let expected =
346+
[0, 0, 0, 0,
347+
0, 0, 0, 0,
348+
0, 0, 0, 0,
349+
0, 0, 0, 0,
350+
0, 1, 1, 0,
351+
0, 1, 1, 0]
352+
V.fromList expected @=? dx
353+
V.fromList [2, 3, 4] @=? s
354+
335355
testBatchToSpaceND :: Test
336356
testBatchToSpaceND =
337357
testCase "testBatchToSpaceND" $ do
@@ -526,6 +546,7 @@ main = defaultMain
526546
, testReshape
527547
, testPad
528548
, testSqrt
549+
, testSlice
529550
, testBatchToSpaceND
530551
, testSpaceToBatchND
531552
, testSqueeze

0 commit comments

Comments
 (0)