Skip to content

Commit 5e6b940

Browse files
committed
Add gradient for ResizeBilinear
1 parent 3cfd96e commit 5e6b940

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,18 @@ opGrad "Tile" _ [toT -> x, toT -> multiples] [dz] =
817817
axes = CoreOps.range 0 (CoreOps.size splitShape) (2 :: Tensor Build Int32)
818818
reshapedDz = CoreOps.reshape dz splitShape
819819

820+
opGrad "ResizeBilinear" nodeDef [toT -> x, _] [dz] =
821+
[ Just $ CoreOps.cast $ CoreOps.resizeBilinear'
822+
(opAttr "align_corners" .~ align)
823+
dz
824+
xSize32
825+
826+
, Nothing
827+
]
828+
where
829+
xSize32 = flatSlice (shape (x :: Tensor Build a)) 1 2
830+
align = lookupAttr nodeDef "align_corners" :: Bool
831+
820832
opGrad "ZerosLike" _ _ _ = [Nothing]
821833
opGrad "Fill" _ _ [dz] = [Nothing, Just $ sum dz rx]
822834
where
@@ -894,6 +906,7 @@ numOutputs o =
894906
"Sum" -> 1
895907
"Tanh" -> 1
896908
"Tile" -> 1
909+
"ResizeBilinear" -> 1
897910
"Transpose" -> 1
898911
"TruncatedNormal" -> 1
899912
"VarHandleOp" -> 1

tensorflow-ops/tests/GradientTest.hs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
3232
import Control.Monad.IO.Class (liftIO)
3333

3434
import qualified TensorFlow.Core as TF
35-
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
35+
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape)
3636
import qualified TensorFlow.Gradient as TF
3737
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
3838
import qualified TensorFlow.Output as TF
@@ -429,6 +429,19 @@ testTile2DGrad = testCase "testTileGrad2D" $ do
429429
shapeX @=? (shapeDX :: V.Vector Int32)
430430
V.fromList [6, 6, 6, 6, 6, 6::Float] @=? (dx :: V.Vector Float)
431431

432+
testResizeBilinearGrad :: Test
433+
testResizeBilinearGrad = testCase "testResizeBilinearGrad" $ do
434+
(dx, shapeDX, shapeX) <- TF.runSession $ do
435+
let shape = TF.vector [1, 3, 2, 1 :: Int32]
436+
x <- TF.render $ TF.fill shape (TF.scalar (1::Float))
437+
let outSize = TF.vector [6, 4 :: Int32]
438+
let y = TF.resizeBilinear x outSize
439+
440+
[dx] <- TF.gradients y [x]
441+
TF.run (dx, TF.shape dx, TF.shape x)
442+
shapeX @=? (shapeDX :: V.Vector Int32)
443+
V.fromList [1, 1, 1, 1, 1, 1::Float] @=? (dx :: V.Vector Float)
444+
432445
matMulGradient :: Test
433446
matMulGradient = testCase "matMulGradients" $ do
434447

@@ -553,6 +566,7 @@ main = defaultMain
553566
, testFillGrad
554567
, testTileGrad
555568
, testTile2DGrad
569+
, testResizeBilinearGrad
556570
, matMulGradient
557571
, matMulGradGrad
558572
, matMulTransposeGradient (False, False)

0 commit comments

Comments
 (0)