Skip to content

Commit 5721c6e

Browse files
committed
Fix open issues with slice-gradient
1 parent fcb01b3 commit 5721c6e

File tree

2 files changed

+16
-9
lines changed

2 files changed

+16
-9
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
711711
gradientSliceSize = shape (x :: Tensor Build Float)
712712

713713
-- Gradient for Slice
714-
-- Create an Nx2 padding where N ist the rank of (grad of) Slice and the first
714+
-- Create an Nx2 padding where N is the rank of (grad of) Slice and the first
715715
-- column represents how many zeros are to be prepended for each dimension, and the second
716716
-- column indicates how many zeros are appended.
717717
-- The number of zeros to prepend is the shape of the beginvec.
@@ -723,13 +723,13 @@ opGrad "Slice" _ [toT -> inputvec, toT -> beginvec, _] [dz] =
723723
[Just $ CoreOps.pad dz paddings, Nothing, Nothing]
724724
where
725725
v1 = vector [1 :: Int32]
726-
input_rank' = CoreOps.rank (inputvec :: Tensor Build Float)
727-
-- For some reason input_rank' has an empty shape
728-
input_rank = CoreOps.reshape input_rank' v1
729-
pad_shape = CoreOps.concat 0 [input_rank, v1]
730-
beforepad = CoreOps.reshape beginvec pad_shape
731-
afterpad = CoreOps.reshape (shape inputvec - shape dz - beginvec) pad_shape
732-
paddings = CoreOps.concat 1 [beforepad, afterpad]
726+
inputRank' = CoreOps.rank (inputvec :: Tensor Build Float)
727+
-- For some reason inputRank' has an empty shape
728+
inputRank = CoreOps.reshape inputRank' v1
729+
padShape = CoreOps.concat 0 [inputRank, v1]
730+
beforePad = CoreOps.reshape beginvec padShape
731+
afterPad = CoreOps.reshape (shape inputvec - shape dz - beginvec) padShape
732+
paddings = CoreOps.concat 1 [beforePad, afterPad]
733733

734734
-- TODO: This could be either Int32 or Int64.
735735
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =

tensorflow-ops/tests/GradientTest.hs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -342,7 +342,14 @@ testSlice =
342342
(z :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 2 :: Int64]
343343
let y = TF.slice x (TF.constant (TF.Shape [3]) [1, 1, 1 :: Int32]) (TF.shape z)
344344
calculateGradWithShape y x
345-
V.fromList [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 1, 0] @=? dx
345+
let expected =
346+
[0, 0, 0, 0,
347+
0, 0, 0, 0,
348+
0, 0, 0, 0,
349+
0, 0, 0, 0,
350+
0, 1, 1, 0,
351+
0, 1, 1, 0]
352+
V.fromList expected @=? dx
346353
V.fromList [2, 3, 4] @=? s
347354

348355
testBatchToSpaceND :: Test

0 commit comments

Comments
 (0)