Skip to content

Commit

Permalink
Update get_validation_data
Browse files Browse the repository at this point in the history
  • Loading branch information
QueensGambit committed Aug 5, 2024
1 parent aafca5c commit 8b5b1d2
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion DeepCrazyhouse/src/training/train_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def main():

update_train_config_via_args(args, train_config)

val_data, x_val, _ = get_validation_data(train_config)
val_data, x_val = get_validation_data(train_config)
input_shape = x_val[0].shape
fill_train_config(train_config, x_val)

Expand Down
2 changes: 1 addition & 1 deletion DeepCrazyhouse/src/training/train_cli_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def get_validation_data(train_config: TrainConfig):
"""
pgn_dataset_arrays_dict = load_pgn_dataset(dataset_type='val', part_id=0, verbose=True, normalize=train_config.normalize)
val_data = get_data_loader(pgn_dataset_arrays_dict, train_config, shuffle=False)
return val_data, x_val, yp_val
return val_data, pgn_dataset_arrays_dict["x"]


def print_model_summary(input_shape: tuple, model, x_val) -> None:
Expand Down

0 comments on commit 8b5b1d2

Please sign in to comment.