|
47 | 47 | " return iterator.get_next()\n",
|
48 | 48 | " return input_fn\n",
|
49 | 49 | "\n",
|
50 |
| - "padding_info = ({'image':[28,28,1],'label':[]})\n", |
| 50 | + "padding_info = ({'image':[784,],'label':[]})\n", |
51 | 51 | "test_input_fn = input_fn_maker('mnist_tfrecord/test/', 'mnist_tfrecord/data_info.csv',batch_size = 512,\n",
|
52 | 52 | " padding = padding_info)\n",
|
53 | 53 | "train_input_fn = input_fn_maker('mnist_tfrecord/train/', 'mnist_tfrecord/data_info.csv', shuffle=True, batch_size = 128,\n",
|
|
63 | 63 | "outputs": [],
|
64 | 64 | "source": [
|
65 | 65 | "def model_fn(features, mode):\n",
|
| 66 | + " features['image'] = tf.reshape(features['image'],[-1,28,28,1])\n", |
66 | 67 | " # shape: [None,28,28,1]\n",
|
67 | 68 | " # create RNN cells:\n",
|
68 |
| - " rnn_fcells = [tf.nn.rnn_cell.GRUCell(dim) for dim in [128,256]]\n", |
69 |
| - " rnn_bcells = [tf.nn.rnn_cell.GRUCell(dim) for dim in [128,256]]\n", |
| 69 | + " rnn_fcells = [tf.nn.rnn_cell.GRUCell(dim,kernel_initializer=tf.orthogonal_initializer) for dim in [128,256]]\n", |
| 70 | + " rnn_bcells = [tf.nn.rnn_cell.GRUCell(dim,kernel_initializer=tf.orthogonal_initializer) for dim in [128,256]]\n", |
70 | 71 | " # stack cells for multi-layers RNN\n",
|
71 | 72 | " multi_rnn_fcell = tf.nn.rnn_cell.MultiRNNCell(rnn_fcells)\n",
|
72 | 73 | " multi_rnn_bcell = tf.nn.rnn_cell.MultiRNNCell(rnn_bcells)\n",
|
|
0 commit comments