Skip to content

Commit 2c30d50

Browse files
committed
Fix sample weights generation for validation data
1 parent 7a86ff7 commit 2c30d50

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

keras/models.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -450,15 +450,17 @@ def fit(self, X, y, batch_size=128, nb_epoch=100, verbose=1, callbacks=[],
450450
if validation_data:
451451
if len(validation_data) == 2:
452452
X_val, y_val = validation_data
453+
X_val = standardize_X(X_val)
454+
y_val = standardize_y(y_val)
453455
sample_weight_val = np.ones(y_val.shape[:-1] + (1,))
454456
elif len(validation_data) == 3:
455457
X_val, y_val, sample_weight_val = validation_data
458+
X_val = standardize_X(X_val)
459+
y_val = standardize_y(y_val)
460+
sample_weight_val = standardize_weights(y_val, sample_weight=sample_weight_val)
456461
else:
457462
raise Exception("Invalid format for validation data; provide a tuple (X_val, y_val) or (X_val, y_val, sample_weight). \
458463
X_val may be a numpy array or a list of numpy arrays depending on your model input.")
459-
X_val = standardize_X(X_val)
460-
y_val = standardize_y(y_val)
461-
sample_weight_val = standardize_weights(y_val, sample_weight=sample_weight_val)
462464
val_ins = X_val + [y_val, sample_weight_val]
463465

464466
elif 0 < validation_split < 1:

0 commit comments

Comments
 (0)