Skip to content

Commit bb4a83f

Browse files
committed
Upload _Dist/NeuralNetworks/i_CNN
1 parent f4f8037 commit bb4a83f

File tree

1 file changed

+81
-0
lines changed
  • _Dist/NeuralNetworks/i_CNN

1 file changed

+81
-0
lines changed

_Dist/NeuralNetworks/i_CNN/CNN.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
import os
2+
import sys
3+
root_path = os.path.abspath("../../../")
4+
if root_path not in sys.path:
5+
sys.path.append(root_path)
6+
7+
import numpy as np
8+
import tensorflow as tf
9+
10+
from _Dist.NeuralNetworks.Base import Generator4d
11+
from _Dist.NeuralNetworks.h_RNN.RNN import Basic3d
12+
from _Dist.NeuralNetworks.NNUtil import Activations
13+
14+
15+
class Basic4d(Basic3d):
16+
def _calculate(self, x, y=None, weights=None, tensor=None, n_elem=1e7, is_training=False):
17+
return super(Basic4d, self)._calculate(x, y, weights, tensor, n_elem / 10, is_training)
18+
19+
20+
class CNN(Basic4d):
21+
def __init__(self, *args, **kwargs):
22+
self.height, self.width = kwargs.pop("height", None), kwargs.pop("width", None)
23+
24+
super(CNN, self).__init__(*args, **kwargs)
25+
self._name_appendix = "CNN"
26+
self._generator_base = Generator4d
27+
28+
self.conv_activations = None
29+
self.n_filters = self.filter_sizes = self.poolings = None
30+
31+
def init_model_param_settings(self):
32+
super(CNN, self).init_model_param_settings()
33+
self.conv_activations = self.model_param_settings.get("conv_activations", "relu")
34+
35+
def init_model_structure_settings(self):
36+
super(CNN, self).init_model_structure_settings()
37+
self.n_filters = self.model_structure_settings.get("n_filters", [32, 32])
38+
self.filter_sizes = self.model_structure_settings.get("filter_sizes", [(3, 3), (3, 3)])
39+
self.poolings = self.model_structure_settings.get("poolings", [None, "max_pool"])
40+
if not len(self.filter_sizes) == len(self.poolings) == len(self.n_filters):
41+
raise ValueError("Length of filter_sizes, n_filters & pooling should be the same")
42+
if isinstance(self.conv_activations, str):
43+
self.conv_activations = [self.conv_activations] * len(self.filter_sizes)
44+
45+
def init_from_data(self, x, y, x_test, y_test, sample_weights, names):
46+
if self.height is None or self.width is None:
47+
assert len(x.shape) == 4, "height and width are not provided, hence len(x.shape) should be 4"
48+
self.height, self.width = x.shape[1:3]
49+
if len(x.shape) == 2:
50+
x = x.reshape(len(x), self.height, self.width, -1)
51+
else:
52+
assert self.height == x.shape[1], "height is set to be {}, but {} found".format(self.height, x.shape[1])
53+
assert self.width == x.shape[2], "width is set to be {}, but {} found".format(self.height, x.shape[2])
54+
if len(x_test.shape) == 2:
55+
x_test = x_test.reshape(len(x_test), self.height, self.width, -1)
56+
super(CNN, self).init_from_data(x, y, x_test, y_test, sample_weights, names)
57+
58+
def _define_input_and_placeholder(self):
59+
self._is_training = tf.placeholder(tf.bool, name="is_training")
60+
self._tfx = tf.placeholder(tf.float32, [None, self.height, self.width, self.n_dim], name="X")
61+
self._tfy = tf.placeholder(tf.float32, [None, self.n_class], name="Y")
62+
63+
def _build_model(self, net=None):
64+
self._model_built = True
65+
if net is None:
66+
net = self._tfx
67+
for i, (filter_size, n_filter, pooling) in enumerate(zip(
68+
self.filter_sizes, self.n_filters, self.poolings
69+
)):
70+
net = tf.layers.conv2d(net, n_filter, filter_size, padding="same")
71+
net = tf.layers.batch_normalization(net, training=self._is_training)
72+
activation = self.conv_activations[i]
73+
if activation is not None:
74+
net = getattr(Activations, activation)(net, activation)
75+
net = tf.layers.dropout(net, training=self._is_training)
76+
if pooling is not None:
77+
net = tf.layers.max_pooling2d(net, 2, 2, name="pool")
78+
79+
fc_shape = np.prod([net.shape[i].value for i in range(1, 4)])
80+
net = tf.reshape(net, [-1, fc_shape])
81+
super(CNN, self)._build_model(net)

0 commit comments

Comments
 (0)