Skip to content

Commit

Permalink
Take network size in TFProcess constructor.
Browse files: browse the repository at this point in the history
* Take network size in TFProcess constructor.
* Make number of blocks and filters required in call to parse.py.

Pull request leela-zero#2361.
  • Loading branch information
TFiFiE authored and gcp committed May 3, 2019
1 parent bf17515 commit db5569c
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 30 deletions.
10 changes: 5 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -318,13 +318,13 @@ This requires a working installation of TensorFlow 1.4 or later:
src/leelaz -w weights.txt
dump_supervised bigsgf.sgf train.out
exit
training/tf/parse.py train.out
training/tf/parse.py 6 128 train.out

This will run and regularly dump Leela Zero weight files to disk, as
well as snapshots of the learning state numbered by the batch number.
If interrupted, training can be resumed with:
This will run and regularly dump Leela Zero weight files (of networks with 6
blocks and 128 filters) to disk, as well as snapshots of the learning state
numbered by the batch number. If interrupted, training can be resumed with:

training/tf/parse.py train.out leelaz-model-batchnumber
training/tf/parse.py 6 128 train.out leelaz-model-batchnumber

# Todo

Expand Down
8 changes: 1 addition & 7 deletions training/tf/net_to_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,8 @@
blocks //= 8
print("Blocks", blocks)

tfprocess = TFProcess()
tfprocess = TFProcess(blocks, channels)
tfprocess.init(batch_size=1, gpus_num=1)
if tfprocess.RESIDUAL_BLOCKS != blocks:
raise ValueError("Number of blocks in tensorflow model doesn't match "\
"number of blocks in input network")
if tfprocess.RESIDUAL_FILTERS != channels:
raise ValueError("Number of filters in tensorflow model doesn't match "\
"number of filters in input network")
tfprocess.replace_weights(weights)
path = os.path.join(os.getcwd(), "leelaz-model")
save_path = tfprocess.saver.save(tfprocess.session, path, global_step=0)
20 changes: 17 additions & 3 deletions training/tf/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,24 +107,38 @@ def split_chunks(chunks, test_ratio):
def main():
parser = argparse.ArgumentParser(
description='Train network from game data.')
parser.add_argument("blockspref",
help="Number of blocks", nargs='?', type=int)
parser.add_argument("filterspref",
help="Number of filters", nargs='?', type=int)
parser.add_argument("trainpref",
help='Training file prefix', nargs='?', type=str)
parser.add_argument("restorepref",
help='Training snapshot prefix', nargs='?', type=str)
parser.add_argument("--blocks", '-b',
help="Number of blocks", type=int)
parser.add_argument("--filters", '-f',
help="Number of filters", type=int)
parser.add_argument("--train", '-t',
help="Training file prefix", type=str)
parser.add_argument("--test", help="Test file prefix", type=str)
parser.add_argument("--restore", type=str,
help="Prefix of tensorflow snapshot to restore from")
parser.add_argument("--logbase", default='leelalogs', type=str,
help="Log file prefix (for tensorboard)")
help="Log file prefix (for tensorboard) (default: %(default)s)")
parser.add_argument("--sample", default=DOWN_SAMPLE, type=int,
help="Rate of data down-sampling to use")
help="Rate of data down-sampling to use (default: %(default)d)")
args = parser.parse_args()

blocks = args.blocks or args.blockspref
filters = args.filters or args.filterspref
train_data_prefix = args.train or args.trainpref
restore_prefix = args.restore or args.restorepref

if not blocks or not filters:
print("Must supply number of blocks and filters")
return

training = get_chunks(train_data_prefix)
if not args.test:
# Generate test by taking 10% of the training chunks.
Expand All @@ -150,7 +164,7 @@ def main():
sample=args.sample,
batch_size=RAM_BATCH_SIZE).parse()

tfprocess = TFProcess()
tfprocess = TFProcess(blocks, filters)
tfprocess.init(RAM_BATCH_SIZE,
logbase=args.logbase,
macrobatch=BATCH_SIZE // RAM_BATCH_SIZE)
Expand Down
30 changes: 15 additions & 15 deletions training/tf/tfprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,10 +111,10 @@ def elapsed(self):
return e

class TFProcess:
def __init__(self):
def __init__(self, residual_blocks, residual_filters):
# Network structure
self.RESIDUAL_FILTERS = 128
self.RESIDUAL_BLOCKS = 6
self.residual_blocks = residual_blocks
self.residual_filters = residual_filters

# model type: full precision (fp32) or mixed precision (fp16)
self.model_dtype = tf.float32
Expand Down Expand Up @@ -602,17 +602,17 @@ def construct_net(self, planes):
# Input convolution
flow = self.conv_block(x_planes, filter_size=3,
input_channels=18,
output_channels=self.RESIDUAL_FILTERS,
output_channels=self.residual_filters,
name="first_conv")
# Residual tower
for i in range(0, self.RESIDUAL_BLOCKS):
for i in range(0, self.residual_blocks):
block_name = "res_" + str(i)
flow = self.residual_block(flow, self.RESIDUAL_FILTERS,
flow = self.residual_block(flow, self.residual_filters,
name=block_name)

# Policy head
conv_pol = self.conv_block(flow, filter_size=1,
input_channels=self.RESIDUAL_FILTERS,
input_channels=self.residual_filters,
output_channels=2,
name="policy_head")
h_conv_pol_flat = tf.reshape(conv_pol, [-1, 2 * 19 * 19])
Expand All @@ -624,7 +624,7 @@ def construct_net(self, planes):

# Value head
conv_val = self.conv_block(flow, filter_size=1,
input_channels=self.RESIDUAL_FILTERS,
input_channels=self.residual_filters,
output_channels=1,
name="value_head")
h_conv_val_flat = tf.reshape(conv_val, [-1, 19 * 19])
Expand Down Expand Up @@ -707,21 +707,21 @@ def gen_block(size, f_in, f_out):

class TFProcessTest(unittest.TestCase):
def test_can_replace_weights(self):
tfprocess = TFProcess()
tfprocess = TFProcess(6, 128)
tfprocess.init(batch_size=1)
# use known data to test replace_weights() works.
data = gen_block(3, 18, tfprocess.RESIDUAL_FILTERS) # input conv
for _ in range(tfprocess.RESIDUAL_BLOCKS):
data = gen_block(3, 18, tfprocess.residual_filters) # input conv
for _ in range(tfprocess.residual_blocks):
data.extend(gen_block(3,
tfprocess.RESIDUAL_FILTERS, tfprocess.RESIDUAL_FILTERS))
tfprocess.residual_filters, tfprocess.residual_filters))
data.extend(gen_block(3,
tfprocess.RESIDUAL_FILTERS, tfprocess.RESIDUAL_FILTERS))
tfprocess.residual_filters, tfprocess.residual_filters))
# policy
data.extend(gen_block(1, tfprocess.RESIDUAL_FILTERS, 2))
data.extend(gen_block(1, tfprocess.residual_filters, 2))
data.append([0.4] * 2*19*19 * (19*19+1))
data.append([0.5] * (19*19+1))
# value
data.extend(gen_block(1, tfprocess.RESIDUAL_FILTERS, 1))
data.extend(gen_block(1, tfprocess.residual_filters, 1))
data.append([0.6] * 19*19 * 256)
data.append([0.7] * 256)
data.append([0.8] * 256)
Expand Down

0 comments on commit db5569c

Please sign in to comment.