This repository has been archived by the owner on Jun 16, 2021. It is now read-only.

Commit

Move preprocess code to main.py
brilee committed Jan 24, 2017
1 parent fd005b2 commit 4e25b00
Showing 2 changed files with 21 additions and 25 deletions.
load_data_sets.py: 0 additions & 24 deletions
@@ -3,7 +3,6 @@
 import numpy as np
 import os
 import struct
-import sys
 
 from features import bulk_extract_features
 import go
@@ -124,26 +123,3 @@ def read(filename):
             next_moves = flat_nextmoves.reshape(data_size, board_size * board_size)
 
         return DataSet(pos_features, next_moves, [], is_test=is_test)
-
-def process_raw_data(*dataset_dirs, processed_dir="processed_data"):
-    sgf_files = list(find_sgf_files(*dataset_dirs))
-    print("%s sgfs found." % len(sgf_files), file=sys.stderr)
-    est_num_positions = len(sgf_files) * 200  # about 200 moves per game
-    print("Estimated number of chunks: %s" % (est_num_positions // CHUNK_SIZE), file=sys.stderr)
-    positions_w_context = itertools.chain(*map(get_positions_from_sgf, sgf_files))
-
-    test_chunk, training_chunks = split_test_training(positions_w_context, est_num_positions)
-    print("Allocating %s positions as test; remainder as training" % len(test_chunk), file=sys.stderr)
-
-    print("Writing test chunk")
-    test_dataset = DataSet.from_positions_w_context(test_chunk, is_test=True)
-    test_filename = os.path.join(processed_dir, "test.chunk.gz")
-    test_dataset.write(test_filename)
-
-    training_datasets = map(DataSet.from_positions_w_context, training_chunks)
-    for i, train_dataset in enumerate(training_datasets):
-        if i % 10 == 0:
-            print("Writing training chunk %s" % i)
-        train_filename = os.path.join(processed_dir, "train%s.chunk.gz" % i)
-        train_dataset.write(train_filename)
-    print("%s chunks written" % (i+1))
main.py: 21 additions & 1 deletion
@@ -44,7 +44,27 @@ def preprocess(*data_sets, processed_dir="processed_data"):
     if not os.path.isdir(processed_dir):
         os.mkdir(processed_dir)
 
-    process_raw_data(*data_sets, processed_dir=processed_dir)
+    sgf_files = list(find_sgf_files(*data_sets))
+    print("%s sgfs found." % len(sgf_files), file=sys.stderr)
+    est_num_positions = len(sgf_files) * 200  # about 200 moves per game
+    print("Estimated number of chunks: %s" % (est_num_positions // CHUNK_SIZE), file=sys.stderr)
+    positions_w_context = itertools.chain(*map(get_positions_from_sgf, sgf_files))
+
+    test_chunk, training_chunks = split_test_training(positions_w_context, est_num_positions)
+    print("Allocating %s positions as test; remainder as training" % len(test_chunk), file=sys.stderr)
+
+    print("Writing test chunk")
+    test_dataset = DataSet.from_positions_w_context(test_chunk, is_test=True)
+    test_filename = os.path.join(processed_dir, "test.chunk.gz")
+    test_dataset.write(test_filename)
+
+    training_datasets = map(DataSet.from_positions_w_context, training_chunks)
+    for i, train_dataset in enumerate(training_datasets):
+        if i % 10 == 0:
+            print("Writing training chunk %s" % i)
+        train_filename = os.path.join(processed_dir, "train%s.chunk.gz" % i)
+        train_dataset.write(train_filename)
+    print("%s chunks written" % (i+1))
 
 def train(processed_dir, read_file=None, save_file=None, epochs=10, logdir=None, checkpoint_freq=10000):
     test_dataset = DataSet.read(os.path.join(processed_dir, "test.chunk.gz"))
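For orientation, a hedged sketch of how the relocated pipeline might be driven end to end, assuming main.py is importable without side effects; the kgs_games directory name is an assumption, while preprocess and DataSet.read come straight from this diff:

    import os
    from main import preprocess
    from load_data_sets import DataSet

    # Scan the (hypothetical) kgs_games/ directory for SGF files, split the
    # extracted positions into test and training sets, and write gzipped chunks.
    preprocess("kgs_games", processed_dir="processed_data")

    # train() later reloads the test chunk the same way:
    test_dataset = DataSet.read(os.path.join("processed_data", "test.chunk.gz"))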

0 comments on commit 4e25b00
