Commit d741c3e

rschlotterbeckfkm3 authored and committed
Add gradient for batchMatMul (tensorflow#246)
1 parent c811037 commit d741c3e

2 files changed: +104 -1 lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 20 additions & 0 deletions
@@ -650,6 +650,25 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
     [ Just $ matMul' (transAttrs True True) y dz
     , Just $ matMul' (transAttrs True True) dz x]
 
+opGrad "BatchMatMul" nodeDef [toT -> x, toT -> y] [dz] =
+    let adjX = lookupAttr nodeDef "adj_x"
+        adjY = lookupAttr nodeDef "adj_y"
+        adjAttrs a b =
+            (opAttr "adj_x" .~ a) . (opAttr "adj_y" .~ b)
+    in case (adjX, adjY) of
+        (False, False) ->
+            [ Just $ CoreOps.batchMatMul' (adjAttrs False True) dz y
+            , Just $ CoreOps.batchMatMul' (adjAttrs True False) x dz]
+        (False, True) ->
+            [ Just $ CoreOps.batchMatMul dz y
+            , Just $ CoreOps.batchMatMul' (adjAttrs True False) dz x]
+        (True, False) ->
+            [ Just $ CoreOps.batchMatMul' (adjAttrs False True) y dz
+            , Just $ CoreOps.batchMatMul x dz]
+        (True, True) ->
+            [ Just $ CoreOps.batchMatMul' (adjAttrs True True) y dz
+            , Just $ CoreOps.batchMatMul' (adjAttrs True True) dz x]
+
 opGrad "Transpose" _ [_, toT -> p] [dz] =
     [ Just $ CoreOps.transpose dz
           (CoreOps.invertPermutation p :: Tensor Build Int32)
@@ -915,6 +934,7 @@ numOutputs o =
     "Add" -> 1
     "AddN" -> 1
     "BatchToSpaceND" -> 1
+    "BatchMatMul" -> 1
     "Cast" -> 1
     "Const" -> 1
     "Concat" -> 1

tensorflow-ops/tests/GradientTest.hs

Lines changed: 84 additions & 1 deletion
@@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
 import Control.Monad.IO.Class (liftIO)
 
 import qualified TensorFlow.Core as TF
-import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
+import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', sum, conjugateTranspose)
 import qualified TensorFlow.Gradient as TF
 import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
 import qualified TensorFlow.Output as TF
@@ -604,6 +604,83 @@ transAttrs :: (TF.Attribute a,
 transAttrs a b =
     (TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)
 
+
+batchMatMulGradient :: Test
+batchMatMulGradient = testCase "batchMatMulGradients" $ do
+
+    let dfBuild = do
+            x <- TF.render $ TF.zeros $ TF.Shape [2, 3, 1 :: Int64]
+            w <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
+            let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float
+            dfs <- TF.gradients f [x]
+            return (x, dfs)
+
+    (xShape, dxShape) <- TF.runSession $ do
+        (x, [dx]) <- TF.build dfBuild
+        TF.run (TF.shape x, TF.shape dx)
+
+    assertEqual "Shape of gradient must match shape of input" xShape (dxShape :: V.Vector Int32)
+
+
+-- test that the gradient of batchMatMul can itself be differentiated
+batchMatMulGradGrad :: Test
+batchMatMulGradGrad = testCase "batchMatMulGradGrad" $ do
+    let width = 2 :: Int64
+        height = 3 :: Int64
+        batch = 4 :: Int64
+
+    let tower = do
+            x <- TF.render $ TF.zeros $ TF.Shape [batch, height, 1]
+            w <- TF.zeroInitializedVariable $ TF.Shape [batch, 1, width]
+            let f = x `TF.batchMatMul` TF.readValue w
+            [dfdx] <- TF.gradients f [x]
+            let f'x = TF.sum dfdx (TF.vector [1, 2 :: Int32])
+            [dfdw] <- TF.gradients f'x [w] -- take the gradient again (this time with respect to w)
+            return [TF.readValue w, TF.expr dfdw]
+
+    TF.runSession $ do
+        [w, dfdw] <- TF.build tower
+        (wShape, dfdwShape) <- TF.run (TF.shape w, TF.shape dfdw)
+        liftIO $ assertEqual "Shape of gradient must match input" wShape (dfdwShape :: V.Vector Int32)
+
+        let step = w `TF.add` dfdw
+        w0 <- TF.run step
+        liftIO $ V.fromList [3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0 :: Float] @=? w0
+
+
+-- test that the gradient of batchMatMul handles adj_x and adj_y correctly
+batchMatMulAdjointGradient :: (Bool, Bool) -> Test
+batchMatMulAdjointGradient axw = testCase ("batchMatMulAdjointGradients " ++ show axw) $ do
+    let (adjX, adjW) = axw
+
+    let dfBuild = do
+            let xShape = TF.Shape [2, 3, 1 :: Int64]
+            let xZeros = TF.zeros xShape
+            x <- TF.render $ if adjX then TF.conjugateTranspose xZeros (TF.vector [0, 2, 1 :: Int32]) else xZeros
+            variable <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
+            let wv = if adjW then TF.conjugateTranspose (TF.readValue variable) (TF.vector [0, 2, 1 :: Int32]) else TF.readValue variable
+            let f = TF.batchMatMul' (adjAttrs adjX adjW) x wv :: TF.Tensor TF.Build Float
+            w <- TF.render wv
+            ds <- TF.gradients f [x, w]
+            return (x, w, ds)
+
+    TF.runSession $ do
+        (x, w, [dx, dw]) <- TF.build dfBuild
+        xShape <- TF.run $ TF.shape x
+        dxShape <- TF.run $ TF.shape dx
+        liftIO $ assertEqual "xShape must match dxShape" xShape (dxShape :: V.Vector Int32)
+
+        wShape <- TF.run $ TF.shape w
+        dwShape <- TF.run $ TF.shape dw
+        liftIO $ assertEqual "wShape must match dwShape" wShape (dwShape :: V.Vector Int32)
+
+adjAttrs :: (TF.Attribute x,
+             TF.Attribute y) =>
+            x -> y -> TF.OpDef -> TF.OpDef
+adjAttrs x y =
+    (TF.opAttr "adj_x" .~ x) . (TF.opAttr "adj_y" .~ y)
+
+
 -- TODO check gradient with regard to filter also
 testConv2DBackpropInputGrad :: Test
 testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
@@ -708,6 +785,12 @@ main = defaultMain
     , matMulTransposeGradient (False, True)
    , matMulTransposeGradient (True, False)
    , matMulTransposeGradient (True, True)
+    , batchMatMulGradient
+    , batchMatMulGradGrad
+    , batchMatMulAdjointGradient (False, False)
+    , batchMatMulAdjointGradient (False, True)
+    , batchMatMulAdjointGradient (True, False)
+    , batchMatMulAdjointGradient (True, True)
     , testConv2DBackpropInputGrad
     , testDepthwiseConv2dGrad
     , testDepthwiseConv2dBackpropInputGrad
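With opGrad and numOutputs both extended, batchMatMul now works with TF.gradients like any other differentiable op. Below is a minimal standalone sketch mirroring batchMatMulGradient above; the TensorFlow.Variable import for zeroInitializedVariable and readValue is an assumption on our part, since the test module's full import list is not visible in this diff:

import Control.Monad.IO.Class (liftIO)
import Data.Int (Int64)
import qualified Data.Vector as V
import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (batchMatMul)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
import qualified TensorFlow.Variable as TF  -- assumed source of zeroInitializedVariable/readValue

main :: IO ()
main = TF.runSession $ do
    -- A batch of two 3x1 inputs times a batch of two 1x2 weight matrices.
    x <- TF.render $ TF.zeros $ TF.Shape [2, 3, 1 :: Int64]
    w <- TF.zeroInitializedVariable $ TF.Shape [2, 1, 2 :: Int64]
    let f = x `TF.batchMatMul` TF.readValue w :: TF.Tensor TF.Build Float
    -- This gradients call only succeeds now that opGrad handles "BatchMatMul".
    [dx] <- TF.gradients f [x]
    dxVal <- TF.run dx
    liftIO $ print (dxVal :: V.Vector Float)

Since x and w are all zeros here, the printed gradient is a vector of zeros with as many elements as x; the point is only that gradient construction no longer fails for BatchMatMul.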
