Skip to content

Commit 96f1c88

Browse files
jcberentsenfkm3
authored andcommitted
Add gradient for ResizeBilinear (tensorflow#239)
1 parent 3cfd96e commit 96f1c88

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,17 @@ opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
817817
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
818818
reshapedDz = CoreOps.reshape dz splitShape
819819

820+
opGrad "ResizeBilinear" nodeDef [toT -> x, _] [dz] =
821+
[ Just $ CoreOps.resizeBilinearGrad'
822+
(opAttr "align_corners" .~ align)
823+
(CoreOps.cast dz)
824+
x
825+
826+
, Nothing
827+
]
828+
where
829+
align = lookupAttr nodeDef "align_corners" :: Bool
830+
820831
opGrad "ZerosLike" _ _ _ = [Nothing]
821832
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
822833
where
@@ -894,6 +905,7 @@ numOutputs o =
894905
"Sum" -> 1
895906
"Tanh" -> 1
896907
"Tile" -> 1
908+
"ResizeBilinear" -> 1
897909
"Transpose" -> 1
898910
"TruncatedNormal" -> 1
899911
"VarHandleOp" -> 1

tensorflow-ops/tests/GradientTest.hs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
3232
import Control.Monad.IO.Class (liftIO)
3333

3434
import qualified TensorFlow.Core as TF
35-
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
35+
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
3636
import qualified TensorFlow.Gradient as TF
3737
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
3838
import qualified TensorFlow.Output as TF
@@ -429,6 +429,22 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
429429
shapeX @=? (shapeDX :: V.Vector Int32)
430430
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
431431

432+
testResizeBilinearGrad :: Test
433+
testResizeBilinearGrad = testCase "testResizeBilinearGrad" $ do
434+
(dx, shapeDX, shapeX) <- TF.runSession $ do
435+
let shape = TF.vector [1, 2, 2, 1 :: Int32]
436+
x <- TF.render $ TF.fill shape (TF.scalar (1 :: Float))
437+
let outSize = TF.vector [4, 4 :: Int32]
438+
align = TF.opAttr "align_corners" .~ True
439+
y = TF.resizeBilinear' align x outSize
440+
441+
[dx] <- TF.gradients y [x]
442+
TF.run (dx, TF.shape dx, TF.shape x)
443+
shapeX @=? (shapeDX :: V.Vector Int32)
444+
let expect = V.fromList [4, 4, 4, 4 :: Float]
445+
near = 0.00001 > (V.sum $ V.zipWith (-) expect (dx :: V.Vector Float))
446+
near @=? True
447+
432448
matMulGradient :: Test
433449
matMulGradient = testCase "matMulGradients" $ do
434450

@@ -553,6 +569,7 @@ main = defaultMain
553569
, testFillGrad
554570
, testTileGrad
555571
, testTile2DGrad
572+
, testResizeBilinearGrad
556573
, matMulGradient
557574
, matMulGradGrad
558575
, matMulTransposeGradient (False, False)

0 commit comments

Comments
 (0)