Skip to content

Commit 1fbd5d4

Browse files
jcberentsenfkm3
authored andcommitted
Add gradients for DepthwiseConv2dNative (#240)
1 parent 4a2e46b commit 1fbd5d4

File tree

2 files changed

+82
-1
lines changed

2 files changed

+82
-1
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -694,6 +694,41 @@ opGrad "Conv2DBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
694694
useCudnnOnGpu = lookupAttr nodeDef "use_cudnn_on_gpu" :: Bool
695695
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
696696

697+
opGrad "DepthwiseConv2dNative" nodeDef [toT -> x, toT -> y] [dz] =
698+
[ Just $ CoreOps.depthwiseConv2dNativeBackpropInput'
699+
((opAttr "strides" .~ strides)
700+
. (opAttr "padding" .~ padding)
701+
. (opAttr "data_format" .~ dataFormat))
702+
(shape x) y dz
703+
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
704+
((opAttr "strides" .~ strides)
705+
. (opAttr "padding" .~ padding)
706+
. (opAttr "data_format" .~ dataFormat))
707+
x (shape y) dz
708+
]
709+
where
710+
strides = lookupAttr nodeDef "strides" :: [Int64]
711+
padding = lookupAttr nodeDef "padding" :: ByteString
712+
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
713+
714+
opGrad "DepthwiseConv2dNativeBackpropInput" nodeDef [_, toT -> x, toT -> y] [dz] =
715+
[ Nothing
716+
, Just $ CoreOps.depthwiseConv2dNativeBackpropFilter'
717+
((opAttr "strides" .~ strides)
718+
. (opAttr "padding" .~ padding)
719+
. (opAttr "data_format" .~ dataFormat))
720+
dz (shape x) y
721+
, Just $ CoreOps.depthwiseConv2dNative'
722+
((opAttr "strides" .~ strides)
723+
. (opAttr "padding" .~ padding)
724+
. (opAttr "data_format" .~ dataFormat))
725+
dz x
726+
]
727+
where
728+
strides = lookupAttr nodeDef "strides" :: [Int64]
729+
padding = lookupAttr nodeDef "padding" :: ByteString
730+
dataFormat = lookupAttr nodeDef "data_format" :: ByteString
731+
697732
opGrad "MaxPool" nodeDef [toT -> x] [dz] =
698733
[ Just $ CoreOps.maxPoolGrad'
699734
((opAttr "ksize" .~ ksize)
@@ -882,6 +917,8 @@ numOutputs o =
882917
"Concat" -> 1
883918
"Conv2D" -> 1
884919
"Conv2DBackpropInput" -> 1
920+
"DepthwiseConv2dNative" -> 1
921+
"DepthwiseConv2dNativeBackpropInput" -> 1
885922
"Div" -> 1
886923
"DynamicStitch" -> 1
887924
"DynamicPartition" ->

tensorflow-ops/tests/GradientTest.hs

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
3333
import Control.Monad.IO.Class (liftIO)
3434

3535
import qualified TensorFlow.Core as TF
36-
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag)
36+
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
3737
import qualified TensorFlow.Gradient as TF
3838
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
3939
import qualified TensorFlow.Output as TF
@@ -596,6 +596,7 @@ transAttrs :: (TF.Attribute a,
596596
transAttrs a b =
597597
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
598598

599+
-- TODO check gradient with regard to filter also
599600
testConv2DBackpropInputGrad :: Test
600601
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
601602
(dx, shapeDX, shapeX) <- TF.runSession $ do
@@ -617,6 +618,47 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
617618
shapeX @=? (shapeDX :: V.Vector Int32)
618619
V.fromList [4::Float] @=? (dx :: V.Vector Float)
619620

621+
testDepthwiseConv2dGrad :: Test
622+
testDepthwiseConv2dGrad = testCase "testDepthwiseConv2dGrad" $ do
623+
(dx, shapeDX, shapeX) <- TF.runSession $ do
624+
let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32]
625+
x <- TF.render $ TF.fill conv_input_shape (TF.scalar (2 :: Float))
626+
627+
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
628+
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
629+
let y = TF.depthwiseConv2dNative'
630+
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
631+
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
632+
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
633+
)
634+
x filter'
635+
636+
[dx] <- TF.gradients y [x]
637+
TF.run (dx, TF.shape dx, TF.shape x)
638+
shapeX @=? (shapeDX :: V.Vector Int32)
639+
V.fromList [1, 1, 1, 1 :: Float] @=? (dx :: V.Vector Float)
640+
641+
-- TODO also test filter gradient
642+
testDepthwiseConv2dBackpropInputGrad :: Test
643+
testDepthwiseConv2dBackpropInputGrad = testCase "testDepthwiseConv2dBackpropInputGrad" $ do
644+
(dx, shapeDX, shapeX) <- TF.runSession $ do
645+
let conv_input_shape = TF.vector [1, 2, 2, 1 :: Int32]
646+
let conv_out_shape = TF.vector [1, 1, 1, 1 :: Int32] -- [batch, h, w, out_channels]
647+
x <- TF.render $ TF.fill conv_out_shape (TF.scalar (1::Float))
648+
649+
let filterShape = TF.vector [2, 2, 1, 1 :: Int32]
650+
filter' <- TF.render $ TF.fill filterShape (TF.scalar (1 :: Float))
651+
let y = TF.depthwiseConv2dNativeBackpropInput'
652+
( (TF.opAttr "strides" .~ [1 :: Int64, 1, 1, 1])
653+
. (TF.opAttr "padding" .~ (BS.pack "VALID"))
654+
. (TF.opAttr "data_format" .~ (BS.pack "NHWC"))
655+
)
656+
conv_input_shape filter' x
657+
658+
[dx] <- TF.gradients y [x]
659+
TF.run (dx, TF.shape dx, TF.shape x)
660+
shapeX @=? (shapeDX :: V.Vector Int32)
661+
V.fromList [4::Float] @=? (dx :: V.Vector Float)
620662

621663
main :: IO ()
622664
main = defaultMain
@@ -658,4 +700,6 @@ main = defaultMain
658700
, matMulTransposeGradient (True, False)
659701
, matMulTransposeGradient (True, True)
660702
, testConv2DBackpropInputGrad
703+
, testDepthwiseConv2dGrad
704+
, testDepthwiseConv2dBackpropInputGrad
661705
]

0 commit comments

Comments
 (0)