Skip to content

Commit 666dce9

Browse files
erikaborfkm3
authored andcommitted
Add gradient for sqrt function (#236)
1 parent 896a0d3 commit 666dce9

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -813,6 +813,10 @@ opGrad "Placeholder" _ _ _ = []
813813
opGrad "VarHandleOp" _ _ _ = []
814814
opGrad "Variable" _ _ _ = []
815815

816+
opGrad "Sqrt" _ [toT -> x] [dz] = [Just $ sq' `CoreOps.mul` dz]
817+
where
818+
sq' = scalar 1 `CoreOps.div` (scalar 2 `CoreOps.mul` CoreOps.sqrt x)
819+
816820
opGrad n nodeDef ins grads =
817821
error $ "no gradient implemented for " ++
818822
show (n, length ins, length grads, showMessage nodeDef, ins)
@@ -863,6 +867,7 @@ numOutputs o =
863867
"SparseSegmentSum" -> 1
864868
"Square" -> 1
865869
"Squeeze" -> 1
870+
"Sqrt" -> 1
866871
"Sub" -> 1
867872
"Sum" -> 1
868873
"Tanh" -> 1

tensorflow-ops/tests/GradientTest.hs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ import Control.Monad(forM_, replicateM, zipWithM)
3232
import Control.Monad.IO.Class (liftIO)
3333

3434
import qualified TensorFlow.Core as TF
35-
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze)
35+
import qualified TensorFlow.GenOps.Core as TF (conv2DBackpropInput', max, maximum, tile, pad, batchToSpaceND, spaceToBatchND, squeeze, sqrt)
3636
import qualified TensorFlow.Gradient as TF
3737
import qualified TensorFlow.Ops as TF hiding (zeroInitializedVariable)
3838
import qualified TensorFlow.Output as TF
@@ -324,6 +324,14 @@ testPad =
324324
V.fromList [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1] @=? dx
325325
V.fromList [2, 2, 3] @=? s
326326

327+
testSqrt :: Test
328+
testSqrt = testCase "testSqrt" $ do
329+
[dx] <- TF.runSession $ do
330+
x <- TF.render $ TF.vector [0.0625 :: Float]
331+
let y = TF.sqrt x
332+
TF.gradients y [x] >>= TF.run
333+
V.fromList [2] @=? dx
334+
327335
testBatchToSpaceND :: Test
328336
testBatchToSpaceND =
329337
testCase "testBatchToSpaceND" $ do
@@ -517,6 +525,7 @@ main = defaultMain
517525
, testExpandDims
518526
, testReshape
519527
, testPad
528+
, testSqrt
520529
, testBatchToSpaceND
521530
, testSpaceToBatchND
522531
, testSqueeze
@@ -530,4 +539,4 @@ main = defaultMain
530539
, matMulTransposeGradient (True, False)
531540
, matMulTransposeGradient (True, True)
532541
, testConv2DBackpropInputGrad
533-
]
542+
]

0 commit comments

Comments
 (0)