Commit 64bf5fa (1 parent: 10cedc2)
Showing 5 changed files with 267 additions and 8 deletions.
tab_transformer_pytorch/__init__.py
@@ -1 +1,2 @@
 from tab_transformer_pytorch.tab_transformer_pytorch import TabTransformer
+from tab_transformer_pytorch.ft_transformer import FTTransformer
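With this change the package root re-exports the new class alongside TabTransformer, so downstream code can import it directly, for example:

    from tab_transformer_pytorch import FTTransformer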
tab_transformer_pytorch/ft_transformer.py (new file)
@@ -0,0 +1,219 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# classes
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# feedforward and attention

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )
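# note: the first Linear above projects to dim * mult * 2; GEGLU then splits that
# in half along the last dimension and gates one half with GELU of the other, so
# the hidden width entering the final Linear is dim * mult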
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)
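# the block above is pre-norm multi-head self-attention over the sequence of
# feature tokens: a fused qkv projection, scaled dot-product scores, softmax
# with dropout, then the heads are merged and projected back to dim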
# transformer

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x
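# each layer applies attention then feedforward, with the residual connections
# added directly in forward; normalization lives inside Attention and FeedForward
# (pre-norm), so the Residual wrapper defined earlier is not used in this file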
# numerical embedder

class NumericalEmbedder(nn.Module):
    def __init__(self, dim, num_numerical_types):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))

    def forward(self, x):
        x = rearrange(x, 'b n -> b n 1')
        return x * self.weights + self.biases
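# each continuous column gets its own learned scale and bias that lift a scalar
# value into a dim-dimensional token: input (batch, num_numerical_types) ->
# output (batch, num_numerical_types, dim)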
# main class

class FTTransformer(nn.Module):
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        mlp_hidden_mults = (4, 2),
        mlp_act = None,
        num_special_tokens = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'

        # categories related calculations

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table

        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        self.register_buffer('categories_offset', categories_offset)
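        # worked example: categories = (10, 5, 6) with num_special_tokens = 2 gives
        # categories_offset = [2, 12, 17], so each column's raw ids are shifted into
        # a disjoint slice of the shared embedding table, with ids 0 and 1 reserved
        # for the special tokens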
        # categorical embedding

        self.categorical_embeds = nn.Embedding(total_tokens, dim)

        # continuous

        self.numerical_embedder = NumericalEmbedder(dim, num_continuous)

        # cls token

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # transformer

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # to logits

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim_out)
        )

    def forward(self, x_categ, x_numer):
        b = x_categ.shape[0]

        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
        x_categ += self.categories_offset

        x_categ = self.categorical_embeds(x_categ)

        # add numerically embedded tokens

        x_numer = self.numerical_embedder(x_numer)

        # concat categorical and numerical

        x = torch.cat((x_categ, x_numer), dim = 1)

        # append cls tokens

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)

        # attend

        x = self.transformer(x)

        # get cls token

        x = x[:, 0]

        # out in the paper is linear(relu(ln(cls)))

        return self.to_logits(x)
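For context, a minimal usage sketch of the class added in this commit; the parameter values and tensor shapes below are illustrative, not taken from the commit itself:

    import torch
    from tab_transformer_pytorch import FTTransformer

    model = FTTransformer(
        categories = (10, 5, 6, 5, 8),   # number of unique values per categorical column
        num_continuous = 10,             # number of continuous columns
        dim = 32,
        depth = 6,
        heads = 8,
        attn_dropout = 0.1,
        ff_dropout = 0.1
    )

    x_categ = torch.randint(0, 5, (1, 5))   # (batch, number of categorical columns), raw ids before offsetting
    x_numer = torch.randn(1, 10)            # (batch, num_continuous)

    pred = model(x_categ, x_numer)          # (1, 1) with the default dim_out = 1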