Skip to content

Commit 9150150

Browse files
rikvdkleijfkm3
authored andcommitted
Added support for tanh activation function (tensorflow#223)
1 parent 61e58fd commit 9150150

File tree

3 files changed

+16
-1
lines changed

3 files changed

+16
-1
lines changed

tensorflow-ops/src/TensorFlow/Gradient.hs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ import Lens.Family2 (Lens', view, (&), (^.), (.~), (%~))
4545
import Lens.Family2.State.Strict (uses)
4646
import Lens.Family2.Stock (at, intAt)
4747
import Lens.Family2.Unchecked (lens, iso)
48-
import Prelude hiding (sum)
48+
import Prelude hiding (sum, tanh)
4949
import Text.Printf (printf)
5050
import qualified Data.Graph.Inductive.Basic as FGL
5151
import qualified Data.Graph.Inductive.Graph as FGL
@@ -76,6 +76,8 @@ import TensorFlow.Ops
7676
, matMul'
7777
, reducedShape
7878
, reluGrad
79+
, tanh
80+
, tanhGrad
7981
, reshape
8082
, scalar
8183
, shape
@@ -459,6 +461,7 @@ opGrad "Abs" _ [toT -> x] [dz] = [Just $ expr dz * signum x]
459461
opGrad "Neg" _ [_] [dz] = [Just $ negate $ expr dz]
460462
opGrad "Relu" _ [toT -> x] [dz] = [Just $ reluGrad dz x]
461463
opGrad "ReluGrad" _ [_, toT -> x ] [dz] = [Just $ reluGrad dz x, Just $ CoreOps.zerosLike x]
464+
opGrad "Tanh" _ [toT -> x] [dz] = [Just $ tanhGrad (tanh x) dz]
462465

463466
opGrad "Concat" _ _ix [dy]
464467
-- Concat concatenates input tensors
@@ -833,6 +836,7 @@ numOutputs o =
833836
"SparseSegmentSum" -> 1
834837
"Sub" -> 1
835838
"Sum" -> 1
839+
"Tanh" -> 1
836840
"Tile" -> 1
837841
"Transpose" -> 1
838842
"TruncatedNormal" -> 1

tensorflow-ops/src/TensorFlow/Ops.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,8 @@ module TensorFlow.Ops
112112
, CoreOps.relu'
113113
, CoreOps.reluGrad
114114
, CoreOps.reluGrad'
115+
, CoreOps.tanh
116+
, CoreOps.tanhGrad
115117
, CoreOps.reshape
116118
, CoreOps.reshape'
117119
, restore

tensorflow-ops/tests/GradientTest.hs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,14 @@ testReluGradGrad = testCase "testReluGradGrad" $ do
282282
TF.gradients y' [x] >>= TF.run
283283
V.fromList [0] @=? dx
284284

285+
testTanhGrad :: Test
286+
testTanhGrad = testCase "testTanhGrad" $ do
287+
[dx] <- TF.runSession $ do
288+
x <- TF.render $ TF.vector [0 :: Float]
289+
let y = TF.tanh x
290+
TF.gradients y [x] >>= TF.run
291+
V.fromList [1] @=? dx
292+
285293
testFillGrad :: Test
286294
testFillGrad = testCase "testFillGrad" $ do
287295
[dx] <- TF.runSession $ do
@@ -427,6 +435,7 @@ main = defaultMain
427435
, testMaximumGradGrad
428436
, testReluGrad
429437
, testReluGradGrad
438+
, testTanhGrad
430439
, testFillGrad
431440
, testTileGrad
432441
, testTile2DGrad

0 commit comments

Comments
 (0)