@@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
 import Control.Monad.IO.Class (liftIO)
 
 import qualified TensorFlow.Core as TF
-import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
+import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', sum, conjugateTranspose)
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
 import qualified TensorFlow.Output as TF
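
Note: the only change in this hunk is bringing four more generated ops (batchMatMul, batchMatMul', sum, conjugateTranspose) into scope for the new tests below. For readers unfamiliar with the op, here is a minimal standalone sketch of what batchMatMul computes; it is not part of the PR, and it assumes only the tensorflow-haskell modules this test file already imports:

import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (batchMatMul)
import qualified TensorFlow.Ops as TF

main :: IO ()
main = TF.runSession $ do
    -- Two independent 2x3 @ 3x1 products, batched along the leading axis.
    let x = TF.constant (TF.Shape [2, 2, 3 :: Int64]) [1 .. 12 :: Float]
        y = TF.constant (TF.Shape [2, 3, 1 :: Int64]) [1 .. 6 :: Float]
    result <- TF.run (x `TF.batchMatMul` y)
    -- Should print [14.0,32.0,122.0,167.0]: each batch slice multiplied independently.
    liftIO $ print (result :: V.Vector Float)

Each leading-axis slice is multiplied on its own, so shapes [2,2,3] x [2,3,1] yield [2,2,1].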
@@ -604,6 +604,83 @@ transAttrs :: (TF.Attribute a,
 transAttrs a b =
     (TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
 
+
+batchMatMulGradient :: Test
+batchMatMulGradient = testCase "batchMatMulGradients" $ do
+
+    let dfBuild = do
+            x <- TF.render $ TF.zeros $ TF.Shape [2, 3, 1 :: Int64]
+            w <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
+            let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float
+            dfs <- TF.gradients f [x]
+            return (x, dfs)
+
+    (xShape, dxShape) <- TF.runSession $ do
+        (x, [dx]) <- TF.build dfBuild
+        TF.run (TF.shape x, TF.shape dx)
+
+    assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
+
+
+-- Test that the gradient of batchMatMul can itself be differentiated.
+batchMatMulGradGrad :: Test
+batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
+    let width = 2 :: Int64
+        height = 3 :: Int64
+        batch = 4 :: Int64
+
+    let tower = do
+            x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
+            w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
+            let f = x `TF.batchMatMul` TF.readValue w
+            [dfdx] <- TF.gradients f [x]
+            let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
+            [dfdw] <- TF.gradients f'x [w] -- take the gradient again (this time w.r.t. w)
+            return [TF.readValue w, TF.expr dfdw]
+
+    TF.runSession $ do
+        [w, dfdw] <- TF.build tower
+        (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
+        liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
+
+        let step = w `TF.add` dfdw
+        w0 <- TF.run step
+        liftIO $ V.fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0 :: Float] @=? w0
+
+
+-- Test that the gradient of batchMatMul handles adj_x and adj_y correctly.
+batchMatMulAdjointGradient :: (Bool, Bool) -> Test
+batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ show axw) $ do
+    let (adjX, adjW) = axw
+
+    let dfBuild = do
+            let xShape = TF.Shape [2, 3, 1 :: Int64]
+            let xZeros = TF.zeros xShape
+            x <- TF.render $ if adjX then TF.conjugateTranspose xZeros (TF.vector [0, 2, 1 :: Int32]) else xZeros
+            variable <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
+            let wv = if adjW then TF.conjugateTranspose (TF.readValue variable) (TF.vector [0, 2, 1 :: Int32]) else TF.readValue variable
+            let f = TF.batchMatMul' (adjAttrs adjX adjW) x wv :: TF.Tensor TF.Build Float
+            w <- TF.render wv
+            ds <- TF.gradients f [x, w]
+            return (x, w, ds)
+
+    TF.runSession $ do
+        (x, w, [dx, dw]) <- TF.build dfBuild
+        xShape <- TF.run $ TF.shape x
+        dxShape <- TF.run $ TF.shape dx
+        liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
+
+        wShape <- TF.run $ TF.shape w
+        dwShape <- TF.run $ TF.shape dw
+        liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
+
+adjAttrs :: (TF.Attribute x,
+             TF.Attribute y) =>
+            x -> y -> TF.OpDef -> TF.OpDef
+adjAttrs x y =
+    (TF.opAttr "adj_x" .~ x) . (TF.opAttr "adj_y" .~ y)
+
+
 -- TODO check gradient with regard to filter also
 testConv2DBackpropInputGrad :: Test
 testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
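
Why w0 comes out as all 3.0 in batchMatMulGradGrad: TF.gradients differentiates the implicit sum of f = x `batchMatMul` w, so d(sum f)/dx[b,i,0] = sum_j w[b,0,j]; reducing that over axes 1 and 2 gives f'x[b] = height * sum_j w[b,0,j], whose derivative with respect to each w[b,0,j] is height = 3, independent of the zero initialization. With w starting at zero, w + dfdw is therefore 3.0 at all batch * 1 * width = 8 positions. A pure-Haskell finite-difference check of that arithmetic (a sketch for illustration, independent of TensorFlow):

-- Per-batch reduced objective from the test: f'x = height * sum of w's row.
f'x :: Int -> [Double] -> Double
f'x height wRow = fromIntegral height * sum wRow

-- Central finite difference of f'x in the j-th coordinate of wRow.
numGrad :: Int -> [Double] -> Int -> Double
numGrad height wRow j =
    let eps = 1e-3
        bump d = [if k == j then v + d else v | (k, v) <- zip [0 ..] wRow]
    in (f'x height (bump eps) - f'x height (bump (-eps))) / (2 * eps)

main :: IO ()
main = print [numGrad 3 [0, 0] j | j <- [0, 1]] -- both coordinates ~3.0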
@@ -708,6 +785,12 @@ main = defaultMain
         , matMulTransposeGradient (False, True)
         , matMulTransposeGradient (True, False)
         , matMulTransposeGradient (True, True)
+        , batchMatMulGradient
+        , batchMatMulGradGrad
+        , batchMatMulAdjointGradient (False, False)
+        , batchMatMulAdjointGradient (False, True)
+        , batchMatMulAdjointGradient (True, False)
+        , batchMatMulAdjointGradient (True, True)
         , testConv2DBackpropInputGrad
         , testDepthwiseConv2dGrad
         , testDepthwiseConv2dBackpropInputGrad
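
The four adjoint combinations registered above rest on the identity that, for real-valued tensors, setting adj_x (or adj_y) on batchMatMul is equivalent to conjugate-transposing the last two axes of that operand before a plain batchMatMul. A standalone sketch of that identity, assuming the same tensorflow-haskell API the test file uses (not part of the PR):

{-# LANGUAGE OverloadedStrings #-}
import Control.Monad.IO.Class (liftIO)
import Data.Int (Int32, Int64)
import qualified Data.Vector as V
import Lens.Family2 ((.~))
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (batchMatMul, batchMatMul', conjugateTranspose)
import qualified TensorFlow.Ops as TF
import qualified TensorFlow.Output as TF

main :: IO ()
main = TF.runSession $ do
    let x = TF.constant (TF.Shape [2, 3, 2 :: Int64]) [1 .. 12 :: Float]
        y = TF.constant (TF.Shape [2, 3, 1 :: Int64]) [1 .. 6 :: Float]
        -- adj_x consumes x with its last two axes swapped: [2,3,2] read as [2,2,3].
        viaAttr = TF.batchMatMul' (TF.opAttr "adj_x" .~ True) x y
        viaPerm = TF.batchMatMul (TF.conjugateTranspose x (TF.vector [0, 2, 1 :: Int32])) y
    (a, b) <- TF.run (viaAttr, viaPerm)
    -- Expected to print True: both routes compute the same product.
    liftIO $ print (a == (b :: V.Vector Float))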