Skip to content

Commit 95c6b6f

Browse files
rikvdkleijfkm3
authored andcommitted
Added support for ExpandDims gradient. (tensorflow#224)
1 parent 9150150 commit 95c6b6f

File tree

2 files changed

+35
-2
lines changed

2 files changed

+35
-2
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -691,8 +691,8 @@ opGrad "MaxPool" nodeDef [toT -> x] [dz] =
691691
padding = lookupAttr nodeDef "padding" :: ByteString
692692
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
693693

694-
opGrad "Reshape" _ [toT -> x, _] [dz] =
695-
[Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
694+
opGrad "Reshape" _ [toT -> x, _] [dz] = [Just $ reshape dz $ shape (x :: Tensor Build a), Nothing]
695+
opGrad "ExpandDims" n xs@[toT -> _, _] dzs@[_] = opGrad "Reshape" n xs dzs
696696

697697
opGrad "OneHot" _ _ _ = [Nothing, Nothing, Nothing, Nothing]
698698
opGrad "TruncatedNormal" _ _ _ = [Nothing]
@@ -810,6 +810,7 @@ numOutputs o =
810810
"DynamicPartition" ->
811811
fromIntegral (lookupAttr o "num_partitions" :: Int64)
812812
"Exp" -> 1
813+
"ExpandDims" -> 1
813814
"Gather" -> 1
814815
"LabelClasses" -> 1
815816
"LabelWeights" -> 1

tensorflow-ops/tests/GradientTest.hs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ import Proto.Tensorflow.Core.Framework.Graph_Fields (node)
4343
import Proto.Tensorflow.Core.Framework.NodeDef_Fields (op)
4444

4545
import qualified Data.ByteString.Char8 as BS
46+
import TensorFlow.Session (SessionT)
4647

4748
testGradientSimple :: Test
4849
testGradientSimple = testCase "testGradientSimple" $ do
@@ -290,6 +291,35 @@ testTanhGrad = testCase "testTanhGrad" $ do
290291
TF.gradients y [x] >>= TF.run
291292
V.fromList [1] @=? dx
292293

294+
testExpandDims :: Test
295+
testExpandDims =
296+
testCase "testExpandDims" $ do
297+
([dx], [s]) <-
298+
TF.runSession $ do
299+
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [1, 2, 3 :: Int64]
300+
let y = TF.expandDims x $ TF.constant (TF.Shape [1]) [0 :: Int32]
301+
calculateGradWithShape y x
302+
V.fromList [1, 1, 1, 1, 1, 1] @=? dx
303+
V.fromList [1, 2, 3] @=? s
304+
305+
testReshape :: Test
306+
testReshape =
307+
testCase "testReshape" $ do
308+
([dx], [s]) <-
309+
TF.runSession $ do
310+
(x :: TF.Tensor TF.Value Float) <- TF.render $ TF.zeros $ TF.Shape [2, 2 :: Int64]
311+
let y = TF.reshape x $ TF.constant (TF.Shape [2]) [1, 4 :: Int32]
312+
calculateGradWithShape y x
313+
V.fromList [1, 1, 1, 1] @=? dx
314+
V.fromList [2, 2] @=? s
315+
316+
calculateGradWithShape :: TF.Tensor TF.Build Float -> TF.Tensor TF.Value Float -> SessionT IO ([V.Vector Float], [V.Vector Int32])
317+
calculateGradWithShape y x = do
318+
gs <- TF.gradients y [x]
319+
xs <- TF.run gs
320+
(shapes :: [V.Vector Int32]) <- mapM (TF.run . TF.shape) gs
321+
return (xs, shapes)
322+
293323
testFillGrad :: Test
294324
testFillGrad = testCase "testFillGrad" $ do
295325
[dx] <- TF.runSession $ do
@@ -436,6 +466,8 @@ main = defaultMain
436466
, testReluGrad
437467
, testReluGradGrad
438468
, testTanhGrad
469+
, testExpandDims
470+
, testReshape
439471
, testFillGrad
440472
, testTileGrad
441473
, testTile2DGrad

0 commit comments

Comments
 (0)