add FT Transformer from Yandex
lucidrains committed Nov 1, 2022
1 parent 10cedc2 commit 64bf5fa
Showing 5 changed files with 267 additions and 8 deletions.
53 changes: 46 additions & 7 deletions README.md
@@ -36,7 +36,36 @@ model = TabTransformer(
x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 to the max number of categories, in the same order as passed into the constructor above
x_cont = torch.randn(1, 10)               # assume continuous values are already normalized individually

pred = model(x_categ, x_cont) # (1, 1)
```

## FT Transformer

<img src="./tab-vs-ft.png" width="500px"></img>

<a href="https://arxiv.org/abs/2106.11959v2">This paper</a> from Yandex improves on Tab Transformer by using a simpler scheme for embedding the continuous numerical values, as shown in the diagram above (courtesy of <a href="https://www.reddit.com/r/MachineLearning/comments/yhdqlj/project_improving_deep_learning_for_tabular_data/">this reddit post</a>).
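
Concretely, each continuous feature gets its own learned weight and bias that lift the scalar value into the model dimension (this is what the `NumericalEmbedder` module in `ft_transformer.py` does). A minimal sketch of the idea, with illustrative shapes only:

```python
import torch

dim, num_continuous = 32, 10

weights = torch.randn(num_continuous, dim)   # one learned vector per continuous feature
biases  = torch.randn(num_continuous, dim)   # one learned bias per continuous feature

x_numer = torch.randn(1, num_continuous)            # (batch, num continuous features)
tokens  = x_numer[..., None] * weights + biases     # (batch, num continuous features, dim)
```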

It is included in this repository for convenient comparison with Tab Transformer.

```python
import torch
from tab_transformer_pytorch import FTTransformer

model = FTTransformer(
    categories = (10, 5, 6, 5, 8),      # tuple containing the number of unique values within each category
    num_continuous = 10,                # number of continuous values
    dim = 32,                           # dimension, paper set at 32
    dim_out = 1,                        # binary prediction, but could be anything
    depth = 6,                          # depth, paper recommended 6
    heads = 8,                          # heads, paper recommends 8
    attn_dropout = 0.1,                 # post-attention dropout
    ff_dropout = 0.1                    # feed forward dropout
)

x_categ = torch.randint(0, 5, (1, 5))     # category values, from 0 to the max number of categories, in the same order as passed into the constructor above
x_numer = torch.randn(1, 10)              # numerical values, assumed to be normalized individually

pred = model(x_categ, x_numer) # (1, 1)
```
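
For completeness, here is a minimal, illustrative training step against a binary target. The dataloader, loss, optimizer and hyperparameters below are placeholders chosen for the sketch, not recommendations from the paper or this repository.

```python
import torch
from torch import nn, optim
from tab_transformer_pytorch import FTTransformer

model = FTTransformer(categories = (10, 5, 6, 5, 8), num_continuous = 10, dim = 32, dim_out = 1, depth = 6, heads = 8)

opt = optim.Adam(model.parameters(), lr = 3e-4)
loss_fn = nn.BCEWithLogitsLoss()

# toy batch standing in for a real dataloader
x_categ = torch.randint(0, 5, (64, 5))
x_numer = torch.randn(64, 10)
labels  = torch.randint(0, 2, (64, 1)).float()

logits = model(x_categ, x_numer)   # (64, 1), raw logits
loss = loss_fn(logits, labels)
loss.backward()
opt.step()
opt.zero_grad()
```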

## Unsupervised Training
@@ -51,11 +80,21 @@
To undergo the type of unsupervised training described in the paper, you can first convert your categories tokens to the appropriate unique ids, and then use Electra on `model.transformer`.
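
As a rough sketch of the id-conversion step only (it mirrors the `categories_offset` buffer the model builds internally; the Electra-style training loop itself is not shown here):

```python
import torch
import torch.nn.functional as F

categories = (10, 5, 6, 5, 8)
num_special_tokens = 2

# per-column offsets, so that every (column, category) pair maps to a unique token id
offsets = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
offsets = offsets.cumsum(dim = -1)[:-1]

x_categ = torch.randint(0, 5, (1, 5))   # raw category values per column
unique_ids = x_categ + offsets          # globally unique ids, ready for masked / replaced token training
```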

```bibtex
@misc{huang2020tabtransformer,
    title   = {TabTransformer: Tabular Data Modeling Using Contextual Embeddings},
    author  = {Xin Huang and Ashish Khetan and Milan Cvitkovic and Zohar Karnin},
    year    = {2020},
    eprint  = {2012.06678},
    archivePrefix = {arXiv},
    primaryClass = {cs.LG}
}
```

```bibtex
@article{Gorishniy2021RevisitingDL,
    title   = {Revisiting Deep Learning Models for Tabular Data},
    author  = {Yu. V. Gorishniy and Ivan Rubachev and Valentin Khrulkov and Artem Babenko},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.11959}
}
```
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'tab-transformer-pytorch',
  packages = find_packages(),
  version = '0.2.0',
  license='MIT',
  description = 'Tab Transformer - Pytorch',
  author = 'Phil Wang',
Binary file added tab-vs-ft.png
1 change: 1 addition & 0 deletions tab_transformer_pytorch/__init__.py
@@ -1 +1,2 @@
from tab_transformer_pytorch.tab_transformer_pytorch import TabTransformer
from tab_transformer_pytorch.ft_transformer import FTTransformer
219 changes: 219 additions & 0 deletions tab_transformer_pytorch/ft_transformer.py
@@ -0,0 +1,219 @@
import torch
import torch.nn.functional as F
from torch import nn, einsum

from einops import rearrange, repeat

# helpers

def exists(val):
    return val is not None

def default(val, d):
    return val if exists(val) else d

# classes

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

# feedforward and attention
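# GEGLU feedforward - the hidden projection is split in half and one half gates the other through a GELU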

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return x * F.gelu(gates)

def FeedForward(dim, mult = 4, dropout = 0.):
    return nn.Sequential(
        nn.LayerNorm(dim),
        nn.Linear(dim, dim * mult * 2),
        GEGLU(),
        nn.Dropout(dropout),
        nn.Linear(dim * mult, dim)
    )

class Attention(nn.Module):
    def __init__(
        self,
        dim,
        heads = 8,
        dim_head = 64,
        dropout = 0.
    ):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.scale = dim_head ** -0.5

        self.norm = nn.LayerNorm(dim)

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Linear(inner_dim, dim, bias = False)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        h = self.heads

        x = self.norm(x)

        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
        q = q * self.scale

        sim = einsum('b h i d, b h j d -> b h i j', q, k)

        attn = sim.softmax(dim = -1)
        attn = self.dropout(attn)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)', h = h)
        return self.to_out(out)

# transformer
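# a pre-norm encoder - each layer is attention followed by a feedforward, both wrapped in residual connections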

class Transformer(nn.Module):
    def __init__(
        self,
        dim,
        depth,
        heads,
        dim_head,
        attn_dropout,
        ff_dropout
    ):
        super().__init__()
        self.layers = nn.ModuleList([])

        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Attention(dim, heads = heads, dim_head = dim_head, dropout = attn_dropout),
                FeedForward(dim, dropout = ff_dropout),
            ]))

    def forward(self, x):
        for attn, ff in self.layers:
            x = attn(x) + x
            x = ff(x) + x

        return x

# numerical embedder
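# each continuous feature gets its own learned weight and bias, lifting the scalar value to a dim-sized token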

class NumericalEmbedder(nn.Module):
    def __init__(self, dim, num_numerical_types):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(num_numerical_types, dim))
        self.biases = nn.Parameter(torch.randn(num_numerical_types, dim))

    def forward(self, x):
        x = rearrange(x, 'b n -> b n 1')
        return x * self.weights + self.biases

# main class

class FTTransformer(nn.Module):
    def __init__(
        self,
        *,
        categories,
        num_continuous,
        dim,
        depth,
        heads,
        dim_head = 16,
        dim_out = 1,
        mlp_hidden_mults = (4, 2),
        mlp_act = None,
        num_special_tokens = 2,
        attn_dropout = 0.,
        ff_dropout = 0.
    ):
        super().__init__()
        assert all(map(lambda n: n > 0, categories)), 'number of each category must be positive'

        # categories related calculations

        self.num_categories = len(categories)
        self.num_unique_categories = sum(categories)

        # create category embeddings table

        self.num_special_tokens = num_special_tokens
        total_tokens = self.num_unique_categories + num_special_tokens

        # for automatically offsetting unique category ids to the correct position in the categories embedding table

        categories_offset = F.pad(torch.tensor(list(categories)), (1, 0), value = num_special_tokens)
        categories_offset = categories_offset.cumsum(dim = -1)[:-1]
        self.register_buffer('categories_offset', categories_offset)

        # categorical embedding

        self.categorical_embeds = nn.Embedding(total_tokens, dim)

        # continuous

        self.numerical_embedder = NumericalEmbedder(dim, num_continuous)

        # cls token

        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # transformer

        self.transformer = Transformer(
            dim = dim,
            depth = depth,
            heads = heads,
            dim_head = dim_head,
            attn_dropout = attn_dropout,
            ff_dropout = ff_dropout
        )

        # to logits

        self.to_logits = nn.Sequential(
            nn.LayerNorm(dim),
            nn.ReLU(),
            nn.Linear(dim, dim_out)
        )

    def forward(self, x_categ, x_numer):
        b = x_categ.shape[0]

        assert x_categ.shape[-1] == self.num_categories, f'you must pass in {self.num_categories} values for your categories input'
        x_categ = x_categ + self.categories_offset  # avoid in-place add, which would mutate the caller's tensor

        x_categ = self.categorical_embeds(x_categ)

        # add numerically embedded tokens

        x_numer = self.numerical_embedder(x_numer)

        # concat categorical and numerical

        x = torch.cat((x_categ, x_numer), dim = 1)

        # append cls tokens

        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b = b)
        x = torch.cat((cls_tokens, x), dim = 1)

        # attend

        x = self.transformer(x)

        # get cls token

        x = x[:, 0]

        # out in the paper is linear(relu(ln(cls)))

        return self.to_logits(x)
