Skip to content

Commit e4acd69

Browse files
rikvdkleijfkm3
authored andcommitted
Support gradients of pad, squeeze, spaceToBatchND, and batchToSpaceND (tensorflow#226)
1 parent 95c6b6f commit e4acd69

File tree

2 files changed

+83
-3
lines changed

2 files changed

+83
-3
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
{-# LANGUAGE ScopedTypeVariables #-}
2121
{-# LANGUAGE TypeFamilies #-}
2222
{-# LANGUAGE ViewPatterns #-}
23+
{-# LANGUAGE TypeApplications #-}
2324

2425
module TensorFlow.Gradient
2526
( GradientCompatible
@@ -693,6 +694,29 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
693694

694695
opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
695696
opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs
697+
opGrad "Squeeze" _ [toT -> x] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a)]
698+
opGrad "Pad" _ [toT -> x, toT -> padPattern] [dz] =
699+
[Just $ CoreOps.slice dz gradientSliceBegin gradientSliceSize, Nothing]
700+
where
701+
v1 = vector [1]
702+
-- For some reason rankx' has an empty shape
703+
rankx' = CoreOps.rank (x :: Tensor Build Float)
704+
rankx = CoreOps.reshape rankx' v1
705+
-- Size of column that is sliced from pad pattern
706+
padPatternSliceSize = CoreOps.concat 0 [rankx, v1]
707+
padPatternSliceBegin = vector [0, 0]
708+
padPatternSliced :: Tensor Build Int32 = CoreOps.slice padPattern padPatternSliceBegin padPatternSliceSize
709+
-- The slice of the pad pattern has the same rank as the pad pattern itself
710+
gradientSliceBegin = CoreOps.reshape padPatternSliced rankx
711+
gradientSliceSize = shape (x :: Tensor Build Float)
712+
713+
-- TODO: This could be either Int32 or Int64.
714+
opGrad "BatchToSpaceND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> crops] [dz] =
715+
[Just $ CoreOps.spaceToBatchND dz blockShape crops, Nothing, Nothing]
716+
717+
-- TODO: This could be either Int32 or Int64.
718+
opGrad "SpaceToBatchND" _ [_, toT @Int32 -> blockShape, toT @Int32 -> paddings] [dz] =
719+
[Just $ CoreOps.batchToSpaceND dz blockShape paddings, Nothing, Nothing]
696720

697721
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
698722
opGrad "TruncatedNormal" _ _ _ = [Nothing]
@@ -800,6 +824,7 @@ numOutputs o =
800824
"Abs" -> 1
801825
"Add" -> 1
802826
"AddN" -> 1
827+
"BatchToSpaceND" -> 1
803828
"Cast" -> 1
804829
"Const" -> 1
805830
"Concat" -> 1
@@ -823,6 +848,7 @@ numOutputs o =
823848
"Min" -> 1
824849
"Mul" -> 1
825850
"Neg" -> 1
851+
"Pad" -> 1
826852
"Placeholder" -> 1
827853
"OneHot" -> 1
828854
"ReadVariableOp" -> 1
@@ -833,8 +859,10 @@ numOutputs o =
833859
"Select" -> 1
834860
"Size" -> 1
835861
"SoftmaxCrossEntropyWithLogits" -> 2
836-
"Square" -> 1
862+
"SpaceToBatchND" -> 1
837863
"SparseSegmentSum" -> 1
864+
"Square" -> 1
865+
"Squeeze" -> 1
838866
"Sub" -> 1
839867
"Sum" -> 1
840868
"Tanh" -> 1

tensorflow-ops/tests/GradientTest.hs

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
3232
import Control.Monad.IO.Class (liftIO)
3333

3434
import qualified TensorFlow.Core as TF
35-
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile)
35+
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
3636
import qualified TensorFlow.Gradient as TF
3737
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
3838
import qualified TensorFlow.Output as TF
@@ -313,6 +313,54 @@ testReshape =
313313
V.fromList [1, 1, 1, 1] @=? dx
314314
V.fromList [2, 2] @=? s
315315

316+
testPad :: Test
317+
testPad =
318+
testCase "testPad" $ do
319+
([dx], [s]) <-
320+
TF.runSession $ do
321+
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2, 3 :: Int64]
322+
let y = TF.pad x $ TF.constant (TF.Shape [3, 2]) [1, 4, 1, 1, 2, 3 :: Int32]
323+
calculateGradWithShape y x
324+
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
325+
V.fromList [2, 2, 3] @=? s
326+
327+
testBatchToSpaceND :: Test
328+
testBatchToSpaceND =
329+
testCase "testBatchToSpaceND" $ do
330+
([dx], [s]) <-
331+
TF.runSession $ do
332+
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.constant (TF.Shape [4, 1, 1, 1 :: Int64]) [1, 2, 3, 4]
333+
shape <- TF.render $ TF.vector [2, 2 :: Int32]
334+
crops <- TF.render $ TF.constant (TF.Shape [2, 2]) [0, 0, 0, 0 :: Int32]
335+
let y = TF.batchToSpaceND x shape crops
336+
calculateGradWithShape y x
337+
V.fromList [1, 1, 1, 1] @=? dx
338+
V.fromList [4, 1, 1, 1] @=? s
339+
340+
testSpaceToBatchND :: Test
341+
testSpaceToBatchND =
342+
testCase "testSpaceToBatchND" $ do
343+
([dx], [s]) <-
344+
TF.runSession $ do
345+
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.constant (TF.Shape [1, 2, 2, 1 :: Int64]) [1, 2, 3, 4]
346+
shape <- TF.render $ TF.vector [2, 2 :: Int32]
347+
paddings <- TF.render $ TF.constant (TF.Shape [2, 2]) [0, 0, 0, 0 :: Int32]
348+
let y = TF.spaceToBatchND x shape paddings
349+
calculateGradWithShape y x
350+
V.fromList [1, 1, 1, 1] @=? dx
351+
V.fromList [1, 2, 2, 1] @=? s
352+
353+
testSqueeze :: Test
354+
testSqueeze =
355+
testCase "testSqueeze" $ do
356+
([dx], [s]) <-
357+
TF.runSession $ do
358+
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64]
359+
let y = TF.squeeze x
360+
calculateGradWithShape y x
361+
V.fromList [1, 1, 1, 1, 1, 1] @=? dx
362+
V.fromList [1, 2, 3] @=? s
363+
316364
calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32])
317365
calculateGradWithShape y x = do
318366
gs <- TF.gradients y [x]
@@ -468,6 +516,10 @@ main = defaultMain
468516
, testTanhGrad
469517
, testExpandDims
470518
, testReshape
519+
, testPad
520+
, testBatchToSpaceND
521+
, testSpaceToBatchND
522+
, testSqueeze
471523
, testFillGrad
472524
, testTileGrad
473525
, testTile2DGrad
@@ -478,4 +530,4 @@ main = defaultMain
478530
, matMulTransposeGradient (True, False)
479531
, matMulTransposeGradient (True, True)
480532
, testConv2DBackpropInputGrad
481-
]
533+
]

0 commit comments

Comments
 (0)