We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 07254e3 commit 316a3bcCopy full SHA for 316a3bc
_Dist/NeuralNetworks/e_AdvancedNN/NN.py
@@ -97,7 +97,7 @@ def init_model_structure_settings(self):
97
def _get_embedding(self, i, n):
98
embedding_size = math.ceil(math.log2(n)) + 1 if self.embedding_size == "log" else self.embedding_size
99
embedding = tf.Variable(tf.truncated_normal(
100
- [1, embedding_size], mean=0, stddev=0.02
+ [n, embedding_size], mean=0, stddev=0.02
101
), name="Embedding{}".format(i))
102
return tf.nn.embedding_lookup(embedding, self._categorical_xs[i], name="Embedded_X{}".format(i))
103
0 commit comments