Skip to content

Commit

Permalink
Added onnx visualizations for transformer model
Browse files Browse the repository at this point in the history
  • Loading branch information
Fuxxel committed Apr 8, 2020
1 parent 8f5a5ae commit 4065ea1
Show file tree
Hide file tree
Showing 10 changed files with 91 additions and 2 deletions.
11 changes: 9 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def mse(x, y):
def classification_accuracy(input, target):
return (input.argmax(-1) == target).float().mean()

def write_options_to_file(options, path):
def write_options_to_file(options, path, additional_info=None):
print("--------------------")
print("Options:")
with open(path, "w") as out_file:
Expand All @@ -53,6 +53,11 @@ def write_options_to_file(options, path):
if type(value) in [str, int, float, bool]:
out_file.write(f"{name}:{value}\n")
print(f"{name}:{value}")

if additional_info:
for name, value in additional_info.items():
out_file.write(f"{name}:{value}\n")
print(f"{name}:{value}")
print("--------------------")

def index_to_coin(index):
Expand Down Expand Up @@ -85,7 +90,9 @@ def main(args):
model_save_path = os.path.join(options.artifact_dir, timestamp)
print(f"Save path: {model_save_path}")
os.makedirs(model_save_path, exist_ok=False)
write_options_to_file(options, os.path.join(model_save_path, "options.txt"))
write_options_to_file(options, os.path.join(model_save_path, "options.txt"), additional_info={
"num_parameters": sum([p.numel() for p in transformer.parameters()])
})
print("Sampling individual timesteps" if options.sample_individual_timesteps else "Collecting multiple timesteps into embedding vector")

criterion = CrossEntropyLoss()
Expand Down
30 changes: 30 additions & 0 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,36 @@
import torch
from torch.nn import Module, TransformerEncoder, TransformerEncoderLayer, Linear

class TransformerVisual(torch.nn.Module):
def __init__(self, options):
super(TransformerVisual, self).__init__()

assert(type(options) == Options)
self.__options = options

self.encoder_layer = TransformerEncoderLayer(d_model=self.__options.num_input_features,
nhead=self.__options.encoder_number_of_heads,
dim_feedforward=self.__options.encoder_feedforward_dimension,
dropout=self.__options.encoder_dropout,
activation=self.__options.encoder_activation)

self.encoder = TransformerEncoder(encoder_layer=self.encoder_layer,
num_layers=self.__options.num_encoder_layers,
norm=self.__options.norm)
self.src_mask = None

def forward(self, x, return_latent=False):
if self.src_mask is None or self.src_mask.size(0) != len(x):
self.src_mask = self.__generate_square_subsequent_mask(len(x)).to(x.device)

output = self.encoder(x, self.src_mask)

return output

def __generate_square_subsequent_mask(self, size):
mask = (torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
return mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))

class TransformerClassifier(torch.nn.Module):
def __init__(self, options):
super(TransformerClassifier, self).__init__()
Expand Down
Binary file added transformer_model.onnx
Binary file not shown.
Binary file added transformer_model.onnx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions transformer_model.onnx.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
5 changes: 5 additions & 0 deletions visualization/onnx_model_options.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
num_encoder_layers:1
encoder_number_of_heads:1
window_size:2048
num_input_features:100
batch_size:16
Binary file added visualization/transformer_model.onnx
Binary file not shown.
Binary file added visualization/transformer_model.onnx.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions visualization/transformer_model.onnx.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
45 changes: 45 additions & 0 deletions visualize_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
from model import TransformerVisual
from options import Options

import torch
from torch.nn import MSELoss, CrossEntropyLoss
from torch.optim import AdamW
from torch.utils.data import DataLoader # random_split
from torch._utils import _accumulate
from torch import randperm
import torch.onnx

import numpy as np
import matplotlib.pyplot as plt

import argparse
import os

def update_options_from_args(options, args):
for arg in vars(args):
setattr(options, arg, getattr(args, arg))

def add_options_to_parser(parser):
dummy_options = Options()
for name in dummy_options.get_option_names():
default_value = getattr(dummy_options, name)
if type(default_value) in [str, int, float]:
parser.add_argument("--" + name, type=type(default_value), default=default_value)
elif type(default_value) == bool:
parser.add_argument("--" + name, action="store_false" if default_value else "store_true")

def main(args):
options = Options()
update_options_from_args(options, args)

transformer_model = TransformerVisual(options)
dummy_input = torch.randn(options.window_size, options.batch_size, options.num_input_features)
torch.onnx.export(transformer_model, dummy_input, "transformer_model.onnx")

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Transformer on crossbar')

add_options_to_parser(parser)

args = parser.parse_args()
main(args)

0 comments on commit 4065ea1

Please sign in to comment.