Use THCUNN backward kernels for Tanh and Sigmoid in Autograd (pytorch…
apaszke authored and soumith committed Apr 29, 2017
1 parent a071ccb commit 457d78a
Showing 2 changed files with 15 additions and 5 deletions.
13 changes: 11 additions & 2 deletions torch/autograd/_functions/pointwise.py
@@ -1,5 +1,6 @@
 from itertools import repeat
 
+from ..._thnn import type2backend
 from ..function import Function, InplaceFunction
 
 
@@ -51,7 +52,11 @@ def forward(self, i):
 
     def backward(self, grad_output):
         result, = self.saved_tensors
-        return grad_output * (1 - result * result)
+        grad_input = grad_output.new()
+        backend = type2backend[type(result)]
+        backend.Tanh_updateGradInput(backend.library_state, None, grad_output,
+                                     grad_input, result)
+        return grad_input
 
 
 class Sigmoid(InplaceFunction):
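
Note: the new backward dispatches to the THNN/THCUNN Tanh_updateGradInput kernel instead of composing Python-level tensor ops. Both paths compute the same quantity, since tanh'(x) = 1 - tanh(x)^2 can be evaluated from the result saved by forward(). A minimal sketch of that identity, using plain tensor ops rather than the backend call (assumes only a working torch install; names are illustrative):

    import torch

    x = torch.randn(5)
    result = torch.tanh(x)           # what forward() saves
    grad_output = torch.randn(5)

    # The removed Python expression: gradient w.r.t. input via the saved output.
    grad_python = grad_output * (1 - result * result)

    # Finite-difference check of tanh'(x) = 1 - tanh(x)^2.
    eps = 1e-4
    fd = (torch.tanh(x + eps) - torch.tanh(x - eps)) / (2 * eps)
    print((fd - (1 - result * result)).abs().max())  # close to zero
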
@@ -67,7 +72,11 @@ def forward(self, i):
 
     def backward(self, grad_output):
         result, = self.saved_tensors
-        return grad_output * ((1 - result) * result)
+        grad_input = grad_output.new()
+        backend = type2backend[type(result)]
+        backend.Sigmoid_updateGradInput(backend.library_state, None, grad_output,
+                                        grad_input, result)
+        return grad_input
 
 
 class Sinh(Function):
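
The Sigmoid change follows the same pattern: the removed expression grad_output * ((1 - result) * result) is the analytic derivative sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)) evaluated from the saved output, and the THNN Sigmoid_updateGradInput kernel computes the same thing natively. A short sketch of the identity (plain tensor ops, illustrative only, assuming a working torch install):

    import torch

    x = torch.randn(5)
    result = torch.sigmoid(x)        # saved by forward()
    grad_output = torch.randn(5)

    # The removed Python backward: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)).
    grad_python = grad_output * ((1 - result) * result)
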
7 changes: 4 additions & 3 deletions torch/nn/functional.py
@@ -4,6 +4,7 @@
 from . import _functions
 from .modules import utils
 from ._functions.padding import ConstantPad2d
+from ..autograd import _functions as _autograd_functions
 from .modules.utils import _single, _pair, _triple
 
 # Convolutions
@@ -407,7 +408,7 @@ def hardshrink(input, lambd=0.5):
 
 
 def tanhshrink(input):
-    return input - _functions.thnn.Tanh()(input)
+    return input - _autograd_functions.Tanh()(input)
 
 
 def softsign(input):
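
Since tanhshrink(x) = x - tanh(x), this hunk only swaps which autograd Function builds the tanh node; the forward values are unchanged. A quick check one could run (assumes a current torch install; deliberately avoids version-specific helpers):

    import torch
    import torch.nn.functional as F

    x = torch.randn(3)
    # tanhshrink(x) = x - tanh(x)
    diff = (F.tanhshrink(x) - (x - torch.tanh(x))).abs().max()
    print(diff)  # ~0
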
@@ -435,11 +436,11 @@ def log_softmax(input):
 
 
 def tanh(input):
-    return _functions.thnn.Tanh()(input)
+    return _autograd_functions.Tanh()(input)
 
 
 def sigmoid(input):
-    return _functions.thnn.Sigmoid()(input)
+    return _autograd_functions.Sigmoid()(input)
 
 
 # etc.
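
As with tanhshrink, rerouting tanh and sigmoid from _functions.thnn to the autograd implementations changes only which backward path runs, not the forward values. A sanity-check sketch (in newer releases F.tanh and F.sigmoid survive as deprecated aliases; torch.tanh and torch.sigmoid are the stable spellings):

    import torch
    import torch.nn.functional as F

    x = torch.randn(3)
    print((F.tanh(x) - torch.tanh(x)).abs().max())        # ~0
    print((F.sigmoid(x) - torch.sigmoid(x)).abs().max())  # ~0
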
