Skip to content

Add gradient for batchMatMul #246

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 8, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions tensorflow-ops/src/TensorFlow/Gradient.hs
Original file line number Diff line number Diff line change
Expand Up @@ -647,6 +647,25 @@ opGrad "MatMul" nodeDef [toT -> x, toT -> y] [dz] =
[ Just $ matMul' (transAttrs True True) y dz
, Just $ matMul' (transAttrs True True) dz x]

-- Gradient of BatchMatMul.  The "adj_x" and "adj_y" attributes are read from
-- the node definition and each of the four combinations selects its own pair
-- of batched products.  The result holds one entry per input of the op
-- (presumably in input order: x first, then y).
opGrad "BatchMatMul" nodeDef [toT -> x, toT -> y] [dz] =
    case (lookupAttr nodeDef "adj_x", lookupAttr nodeDef "adj_y") of
        (False, False) ->
            [ Just (CoreOps.batchMatMul' (withAdj False True) dz y)
            , Just (CoreOps.batchMatMul' (withAdj True False) x dz)
            ]
        (False, True) ->
            [ Just (CoreOps.batchMatMul dz y)
            , Just (CoreOps.batchMatMul' (withAdj True False) dz x)
            ]
        (True, False) ->
            [ Just (CoreOps.batchMatMul' (withAdj False True) y dz)
            , Just (CoreOps.batchMatMul x dz)
            ]
        (True, True) ->
            [ Just (CoreOps.batchMatMul' (withAdj True True) y dz)
            , Just (CoreOps.batchMatMul' (withAdj True True) dz x)
            ]
  where
    -- Sets both adjoint attributes on the op being built.
    withAdj adjA adjB = (opAttr "adj_x" .~ adjA) . (opAttr "adj_y" .~ adjB)

opGrad "Transpose" _ [_, toT -> p] [dz] =
[ Just $ CoreOps.transpose dz
(CoreOps.invertPermutation p :: Tensor Build Int32)
Expand Down Expand Up @@ -912,6 +931,7 @@ numOutputs o =
"Add" -> 1
"AddN" -> 1
"BatchToSpaceND" -> 1
"BatchMatMul" -> 1
"Cast" -> 1
"Const" -> 1
"Concat" -> 1
Expand Down
85 changes: 84 additions & 1 deletion tensorflow-ops/tests/GradientTest.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
import Control.Monad.IO.Class (liftIO)

import qualified TensorFlow.Core as TF
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput')
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, resizeBilinear', tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt, slice, shape, diag, depthwiseConv2dNative', depthwiseConv2dNativeBackpropInput', batchMatMul, batchMatMul', sum, conjugateTranspose)
import qualified TensorFlow.Gradient as TF
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable, shape)
import qualified TensorFlow.Output as TF
Expand Down Expand Up @@ -596,6 +596,83 @@ transAttrs :: (TF.Attribute a,
transAttrs a b =
(TF.opAttr "transpose_a" .~ a) . (TF.opAttr "transpose_b" .~ b)


-- Builds a batched product of a zero tensor of shape [2,3,1] with a
-- zero-initialised variable of shape [2,1,2] and checks that the gradient
-- with respect to the first operand has that operand's shape.
batchMatMulGradient :: Test
batchMatMulGradient = testCase "batchMatMulGradients" $ do
    (inputShape, gradShape) <- TF.runSession $ do
        (input, [grad]) <- TF.build graph
        TF.run (TF.shape input, TF.shape grad)
    assertEqual "Shape of gradient must match shape of input"
        inputShape
        (gradShape :: V.Vector Int32)
  where
    -- Graph under test: returns the rendered input together with the list of
    -- gradients requested for it.
    graph = do
        input <- TF.render (TF.zeros (TF.Shape [2, 3, 1 :: Int64]))
        weights <- TF.zeroInitializedVariable (TF.Shape [2, 1, 2 :: Int64])
        let out = TF.batchMatMul input (TF.readValue weights)
                    :: TF.Tensor TF.Build Float
        grads <- TF.gradients out [input]
        return (input, grads)


-- Checks that the gradient produced for batchMatMul is itself differentiable:
-- the gradient with respect to the input is reduced with TF.sum and then
-- differentiated a second time, now with respect to the variable.
batchMatMulGradGrad :: Test
batchMatMulGradGrad = testCase "batchMatMulGradGrad" $
    TF.runSession $ do
        [weights, grad] <- TF.build graph
        (weightsShape, gradShape) <- TF.run (TF.shape weights, TF.shape grad)
        liftIO $
            assertEqual "Shape of gradient must match input"
                weightsShape
                (gradShape :: V.Vector Int32)

        -- Adding the second-order gradient to the variable's value must
        -- yield 3.0 in each of the eight entries.
        updated <- TF.run (TF.add weights grad)
        liftIO $ V.fromList (replicate 8 (3.0 :: Float)) @=? updated
  where
    batch, height, width :: Int64
    batch = 4
    height = 3
    width = 2

    graph = do
        input <- TF.render (TF.zeros (TF.Shape [batch, height, 1]))
        var <- TF.zeroInitializedVariable (TF.Shape [batch, 1, width])
        let out = TF.batchMatMul input (TF.readValue var)
        [dOutDInput] <- TF.gradients out [input]
        let reduced = TF.sum dOutDInput (TF.vector [1, 2 :: Int32])
        -- Second differentiation, this time over the variable.
        [dReducedDVar] <- TF.gradients reduced [var]
        return [TF.readValue var, TF.expr dReducedDVar]


-- Checks that the batchMatMul gradient honours adj_x and adj_y: for the given
-- flag pair, each gradient must have the same shape as the operand it was
-- taken with respect to.
batchMatMulAdjointGradient :: (Bool, Bool) -> Test
batchMatMulAdjointGradient axw =
    testCase ("batchMatMulAdjointGradients " ++ show axw) $
        TF.runSession $ do
            (x, w, [dx, dw]) <- TF.build graph

            xShape <- TF.run (TF.shape x)
            dxShape <- TF.run (TF.shape dx)
            sameShape "xShape must match dxShape" xShape dxShape

            wShape <- TF.run (TF.shape w)
            dwShape <- TF.run (TF.shape dw)
            sameShape "wShape must match dwShape" wShape dwShape
  where
    (adjX, adjW) = axw

    -- Permutation that keeps axis 0 and swaps axes 1 and 2.
    swapLastTwo = TF.vector [0, 2, 1 :: Int32]

    -- Applies conjugateTranspose with the permutation above when the flag is
    -- set, so that an operand marked as adjoint is fed in pre-transposed.
    adjointIf :: Bool -> TF.Tensor TF.Build Float -> TF.Tensor TF.Build Float
    adjointIf flag t
        | flag = TF.conjugateTranspose t swapLastTwo
        | otherwise = t

    sameShape msg expected actual =
        liftIO (assertEqual msg expected (actual :: V.Vector Int32))

    graph = do
        x <- TF.render (adjointIf adjX (TF.zeros (TF.Shape [2, 3, 1 :: Int64])))
        variable <- TF.zeroInitializedVariable (TF.Shape [2, 1, 2 :: Int64])
        let wExpr = adjointIf adjW (TF.readValue variable)
            f = TF.batchMatMul' (adjAttrs adjX adjW) x wExpr
                    :: TF.Tensor TF.Build Float
        w <- TF.render wExpr
        ds <- TF.gradients f [x, w]
        return (x, w, ds)

-- Op-definition modifier that writes the two given values into the "adj_x"
-- and "adj_y" attributes.
adjAttrs :: (TF.Attribute a, TF.Attribute b) => a -> b -> TF.OpDef -> TF.OpDef
adjAttrs adjX adjY opDef =
    (TF.opAttr "adj_x" .~ adjX) $ (TF.opAttr "adj_y" .~ adjY) opDef


-- TODO check gradient with regard to filter also
testConv2DBackpropInputGrad :: Test
testConv2DBackpropInputGrad = testCase "testConv2DBackpropInputGrad" $ do
Expand Down Expand Up @@ -699,6 +776,12 @@ main = defaultMain
, matMulTransposeGradient (False, True)
, matMulTransposeGradient (True, False)
, matMulTransposeGradient (True, True)
, batchMatMulGradient
, batchMatMulGradGrad
, batchMatMulAdjointGradient (False, False)
, batchMatMulAdjointGradient (False, True)
, batchMatMulAdjointGradient (True, False)
, batchMatMulAdjointGradient (True, True)
, testConv2DBackpropInputGrad
, testDepthwiseConv2dGrad
, testDepthwiseConv2dBackpropInputGrad
Expand Down