Skip to content

Commit c811037

Browse files
rschlotterbeckfkm3
authored andcommitted
Add gradient for sigmoid (tensorflow#245)
1 parent 1fbd5d4 commit c811037

File tree

3 files changed

+15
-0
lines changed

3 files changed

+15
-0
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,8 @@ import TensorFlow.Ops
8585
, shape
8686
, softmaxCrossEntropyWithLogits
8787
, sum
88+
, sigmoid
89+
, sigmoidGrad
8890
, scalarize
8991
, vector
9092
, zerosLike
@@ -481,6 +483,7 @@ opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
481483
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
482484
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
483485
opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
486+
opGrad "Sigmoid" _ [toT -> x] [dz] = [Just $ sigmoidGrad (sigmoid x) dz]
484487

485488
opGrad "Concat" _ _ix [dy]
486489
-- Concat concatenates input tensors
@@ -947,6 +950,7 @@ numOutputs o =
947950
"ReluGrad" -> 1
948951
"Reshape" -> 1
949952
"Select" -> 1
953+
"Sigmoid" -> 1
950954
"Size" -> 1
951955
"Slice" -> 1
952956
"SoftmaxCrossEntropyWithLogits" -> 2

tensorflow-ops/src/TensorFlow/Ops.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ module TensorFlow.Ops
123123
, scalar'
124124
, shape
125125
, shape'
126+
, CoreOps.sigmoid
127+
, CoreOps.sigmoidGrad
126128
, CoreOps.sign
127129
, CoreOps.sign'
128130
, CoreOps.size

tensorflow-ops/tests/GradientTest.hs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,14 @@ testTanhGrad = testCase "testTanhGrad" $ do
368368
TF.gradients y [x] >>= TF.run
369369
V.fromList [1] @=? dx
370370

371+
testSigmoidGrad :: Test
372+
testSigmoidGrad = testCase "testSigmoidGrad" $ do
373+
[dx] <- TF.runSession $ do
374+
x <- TF.render $ TF.vector [0 :: Float]
375+
let y = TF.sigmoid x
376+
TF.gradients y [x] >>= TF.run
377+
V.fromList [0.25] @=? dx
378+
371379
testExpandDims :: Test
372380
testExpandDims =
373381
testCase "testExpandDims" $ do
@@ -681,6 +689,7 @@ main = defaultMain
681689
, testReluGrad
682690
, testReluGradGrad
683691
, testTanhGrad
692+
, testSigmoidGrad
684693
, testExpandDims
685694
, testReshape
686695
, testPad

0 commit comments

Comments
 (0)