Merge pull request #17 from lukovnikov/master
activation function in BERTIntermediate
thomwolf authored Nov 13, 2018
2 parents 5cd8d7a + 470076e commit 8513741
Showing 1 changed file with 11 additions and 2 deletions.
modeling.py
@@ -25,6 +25,7 @@
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
+from six import string_types

def gelu(x):
"""Implementation of the gelu activation function.
Expand All @@ -34,6 +35,13 @@ def gelu(x):
return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


+def swish(x):
+    return x * torch.sigmoid(x)
+
+
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}


class BertConfig(object):
"""Configuration class to store the configuration of a `BertModel`.
"""
@@ -60,7 +68,7 @@ def __init__(self,
intermediate_size: The size of the "intermediate" (i.e., feed-forward)
layer in the Transformer encoder.
hidden_act: The non-linear activation function (function or string) in the
-encoder and pooler.
+encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
hidden_dropout_prob: The dropout probabilitiy for all fully connected
layers in the embeddings, encoder, and pooler.
attention_probs_dropout_prob: The dropout ratio for the attention
@@ -237,7 +245,8 @@ class BERTIntermediate(nn.Module):
    def __init__(self, config):
        super(BERTIntermediate, self).__init__()
        self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
-        self.intermediate_act_fn = gelu
+        self.intermediate_act_fn = ACT2FN[config.hidden_act] \
+            if isinstance(config.hidden_act, string_types) else config.hidden_act

    def forward(self, hidden_states):
        hidden_states = self.dense(hidden_states)
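
For reference, here is a minimal standalone sketch of the dispatch pattern this commit introduces. The function bodies and the ACT2FN table are copied from the diff above; the test loop and the random tensor are illustrative assumptions, not part of the commit.

import math
import torch
from six import string_types

def gelu(x):
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def swish(x):
    return x * torch.sigmoid(x)

ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}

# config.hidden_act may now be either a string key into ACT2FN or a callable;
# both resolve to a plain function, as in BERTIntermediate.__init__ above.
for hidden_act in ("gelu", "relu", "swish", torch.tanh):
    act_fn = ACT2FN[hidden_act] if isinstance(hidden_act, string_types) else hidden_act
    print(act_fn(torch.randn(2, 4)).shape)  # each call prints torch.Size([2, 4])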
