Skip to content

Commit 8ee1430

Browse files
authored
Fixed weighted mse
1 parent 68d1cbf commit 8ee1430

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

_Dist/NeuralNetworks/NNUtil.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ class Losses:
141141
def mse(y, pred, _, weights=None):
142142
if weights is None:
143143
return tf.losses.mean_squared_error(y, pred)
144-
return tf.losses.mean_squared_error(y, pred, weights)
144+
return tf.losses.mean_squared_error(y, pred, tf.reshape(weights, [-1, 1]))
145145

146146
@staticmethod
147147
def cross_entropy(y, pred, already_prob, weights=None):

0 commit comments

Comments
 (0)