@@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
33
33
import Control.Monad.IO.Class (liftIO )
34
34
35
35
import qualified TensorFlow.Core as TF
36
- import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput' , max , maximum , resizeBilinear' , tile , pad , batchToSpaceND , spaceToBatchND , squeeze , sqrt , slice , shape , diag )
36
+ import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput' , max , maximum , resizeBilinear' , tile , pad , batchToSpaceND , spaceToBatchND , squeeze , sqrt , slice , shape , diag , depthwiseConv2dNative' , depthwiseConv2dNativeBackpropInput' )
37
37
import qualified TensorFlow.Gradient as TF
38
38
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable , shape )
39
39
import qualified TensorFlow.Output as TF
@@ -596,6 +596,7 @@ transAttrs :: (TF.Attribute a,
596
596
transAttrs a b =
597
597
(TF. opAttr " transpose_a" .~ a) . (TF. opAttr " transpose_b" .~ b)
598
598
599
+ -- TODO check gradient with regard to filter also
599
600
testConv2DBackpropInputGrad :: Test
600
601
testConv2DBackpropInputGrad = testCase " testConv2DBackpropInputGrad" $ do
601
602
(dx, shapeDX, shapeX) <- TF. runSession $ do
@@ -617,6 +618,47 @@ testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
617
618
shapeX @=? (shapeDX :: V. Vector Int32 )
618
619
V. fromList [4 :: Float ] @=? (dx :: V. Vector Float )
619
620
621
+ testDepthwiseConv2dGrad :: Test
622
+ testDepthwiseConv2dGrad = testCase " testDepthwiseConv2dGrad" $ do
623
+ (dx, shapeDX, shapeX) <- TF. runSession $ do
624
+ let conv_input_shape = TF. vector [1 , 2 , 2 , 1 :: Int32 ]
625
+ x <- TF. render $ TF. fill conv_input_shape (TF. scalar (2 :: Float ))
626
+
627
+ let filterShape = TF. vector [2 , 2 , 1 , 1 :: Int32 ]
628
+ filter' <- TF. render $ TF. fill filterShape (TF. scalar (1 :: Float ))
629
+ let y = TF. depthwiseConv2dNative'
630
+ ( (TF. opAttr " strides" .~ [1 :: Int64 , 1 , 1 , 1 ])
631
+ . (TF. opAttr " padding" .~ (BS. pack " VALID" ))
632
+ . (TF. opAttr " data_format" .~ (BS. pack " NHWC" ))
633
+ )
634
+ x filter'
635
+
636
+ [dx] <- TF. gradients y [x]
637
+ TF. run (dx, TF. shape dx, TF. shape x)
638
+ shapeX @=? (shapeDX :: V. Vector Int32 )
639
+ V. fromList [1 , 1 , 1 , 1 :: Float ] @=? (dx :: V. Vector Float )
640
+
641
+ -- TODO also test filter gradient
642
+ testDepthwiseConv2dBackpropInputGrad :: Test
643
+ testDepthwiseConv2dBackpropInputGrad = testCase " testDepthwiseConv2dBackpropInputGrad" $ do
644
+ (dx, shapeDX, shapeX) <- TF. runSession $ do
645
+ let conv_input_shape = TF. vector [1 , 2 , 2 , 1 :: Int32 ]
646
+ let conv_out_shape = TF. vector [1 , 1 , 1 , 1 :: Int32 ] -- [batch, h, w, out_channels]
647
+ x <- TF. render $ TF. fill conv_out_shape (TF. scalar (1 :: Float ))
648
+
649
+ let filterShape = TF. vector [2 , 2 , 1 , 1 :: Int32 ]
650
+ filter' <- TF. render $ TF. fill filterShape (TF. scalar (1 :: Float ))
651
+ let y = TF. depthwiseConv2dNativeBackpropInput'
652
+ ( (TF. opAttr " strides" .~ [1 :: Int64 , 1 , 1 , 1 ])
653
+ . (TF. opAttr " padding" .~ (BS. pack " VALID" ))
654
+ . (TF. opAttr " data_format" .~ (BS. pack " NHWC" ))
655
+ )
656
+ conv_input_shape filter' x
657
+
658
+ [dx] <- TF. gradients y [x]
659
+ TF. run (dx, TF. shape dx, TF. shape x)
660
+ shapeX @=? (shapeDX :: V. Vector Int32 )
661
+ V. fromList [4 :: Float ] @=? (dx :: V. Vector Float )
620
662
621
663
main :: IO ()
622
664
main = defaultMain
@@ -658,4 +700,6 @@ main = defaultMain
658
700
, matMulTransposeGradient (True , False )
659
701
, matMulTransposeGradient (True , True )
660
702
, testConv2DBackpropInputGrad
703
+ , testDepthwiseConv2dGrad
704
+ , testDepthwiseConv2dBackpropInputGrad
661
705
]
0 commit comments