forked from raminmh/CfC
-
Notifications
You must be signed in to change notification settings - Fork 0
/
train_et_smnist.py
74 lines (62 loc) · 2.18 KB
/
train_et_smnist.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import sys
import time
import tensorflow as tf
import argparse
from irregular_sampled_datasets import ETSMnistData
from tf_cfc import CfcCell, MixedCfcCell
import numpy as np
# Command-line options. The --model choices mirror the cell-selection branches
# further down in this script; keep the two lists in sync. Validating here
# makes a mistyped --model exit with an argparse usage error before the
# dataset is constructed, instead of raising ValueError only afterwards.
parser = argparse.ArgumentParser(
    description="Train a CfC-family recurrent classifier on ETSMnistData."
)
parser.add_argument(
    "--model",
    default="cfc",
    choices=["cfc", "no_gate", "minimal", "mixed"],
    help="recurrent cell variant to train",
)
parser.add_argument(
    "--size", default=64, type=int, help="number of units in the recurrent cell"
)
parser.add_argument(
    "--epochs", default=200, type=int, help="number of training epochs"
)
parser.add_argument(
    "--lr", default=0.0005, type=float, help="RMSprop learning rate"
)
args = parser.parse_args()

# Batch-major arrays (time_major=False), matching the batch-major RNN layer
# built later in this script.
data = ETSMnistData(time_major=False)
# Hyperparameters handed to CfcCell / MixedCfcCell via ``hparams``.
# Key order is kept identical to the original literal. The "no_gate" and
# "minimal" flags start False and are flipped by the --model selection.
# NOTE(review): meanings below are inferred from key names — confirm in tf_cfc.
CFC_CONFIG = dict(
    backbone_activation="gelu",
    backbone_dr=0.0,  # presumably the backbone dropout rate
    forget_bias=3.0,
    backbone_units=128,
    backbone_layers=1,
    weight_decay=0,
    use_lstm=False,
    no_gate=False,
    minimal=False,
)
# Choose the recurrent cell from --model.
#   "mixed"              -> MixedCfcCell with the shared config.
#   "cfc"                -> CfcCell with the config unchanged.
#   "no_gate"/"minimal"  -> CfcCell with the same-named CFC_CONFIG flag set
#                           to True (the model name doubles as the config key).
if args.model == "mixed":
    cell = MixedCfcCell(units=args.size, hparams=CFC_CONFIG)
elif args.model in ("cfc", "no_gate", "minimal"):
    if args.model != "cfc":
        CFC_CONFIG[args.model] = True
    cell = CfcCell(units=args.size, hparams=CFC_CONFIG)
else:
    raise ValueError("Unknown model type '{}'".format(args.model))
# Model inputs, all batch-major with data.pad_size timesteps:
#   pixel: one feature per step, shape (pad_size, 1)
#   time:  one feature per step, shape (pad_size, 1)
#   mask:  boolean per step, shape (pad_size,); the Keras RNN layer skips
#          timesteps whose mask entry is False.
pixel_input = tf.keras.Input(name="pixel", shape=(data.pad_size, 1))
time_input = tf.keras.Input(name="time", shape=(data.pad_size, 1))
mask_input = tf.keras.Input(name="mask", shape=(data.pad_size,), dtype=tf.bool)

# The cell consumes the (pixel, time) pair at each step. Only the final
# output is kept (return_sequences=False), and a 10-unit Dense head turns it
# into class logits.
rnn = tf.keras.layers.RNN(cell, return_sequences=False, time_major=False)
dense_layer = tf.keras.layers.Dense(10)
output_states = rnn((pixel_input, time_input), mask=mask_input)
y = dense_layer(output_states)
model = tf.keras.Model(
    inputs=[pixel_input, time_input, mask_input],
    outputs=[y],
)

# The Dense head has no softmax, so the loss is told it receives logits.
model.compile(
    optimizer=tf.keras.optimizers.RMSprop(args.lr),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
)
model.summary()
# Fit and evaluate
hist = model.fit(
    x=(data.train_events, data.train_elapsed, data.train_mask),
    y=data.train_y,
    batch_size=128,
    epochs=args.epochs,
)
# With exactly one compiled metric, evaluate() returns [loss, accuracy].
# The same batch size as training is used so evaluation is batched alike.
# Note: despite its name, best_test_acc is the accuracy of the final-epoch
# weights; no best-epoch checkpointing happens in this script.
_, best_test_acc = model.evaluate(
    x=(data.test_events, data.test_elapsed, data.test_mask),
    y=data.test_y,
    batch_size=128,
)
# Report the headline metric explicitly. Previously it was computed and then
# discarded.
print("Test accuracy: {:.2f}%".format(100 * best_test_acc))