Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions tools/extra/summarize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
#!/usr/bin/env python

"""Net summarization tool.

This tool summarizes the structure of a net in a concise but comprehensive
tabular listing, taking a prototxt file as input.

Use this tool to check at a glance that the computation you've specified is the
computation you expect.
"""

from caffe.proto import caffe_pb2
from google import protobuf
import re
import argparse

# ANSI codes for coloring blobs (used cyclically)
COLORS = ['92', '93', '94', '95', '97', '96', '42', '43;30', '100',
'444', '103;30', '107;30']
DISCONNECTED_COLOR = '41'

def read_net(filename):
net = caffe_pb2.NetParameter()
with open(filename) as f:
protobuf.text_format.Parse(f.read(), net)
return net

def format_param(param):
out = []
if len(param.name) > 0:
out.append(param.name)
if param.lr_mult != 1:
out.append('x{}'.format(param.lr_mult))
if param.decay_mult != 1:
out.append('Dx{}'.format(param.decay_mult))
return ' '.join(out)

def printed_len(s):
return len(re.sub(r'\033\[[\d;]+m', '', s))

def print_table(table, max_width):
"""Print a simple nicely-aligned table.

table must be a list of (equal-length) lists. Columns are space-separated,
and as narrow as possible, but no wider than max_width. Text may overflow
columns; note that unlike string.format, this will not affect subsequent
columns, if possible."""

max_widths = [max_width] * len(table[0])
column_widths = [max(printed_len(row[j]) + 1 for row in table)
for j in range(len(table[0]))]
column_widths = [min(w, max_w) for w, max_w in zip(column_widths, max_widths)]

for row in table:
row_str = ''
right_col = 0
for cell, width in zip(row, column_widths):
right_col += width
row_str += cell + ' '
row_str += ' ' * max(right_col - printed_len(row_str), 0)
print row_str

def summarize_net(net):
disconnected_tops = set()
for lr in net.layer:
disconnected_tops |= set(lr.top)
disconnected_tops -= set(lr.bottom)

table = []
colors = {}
for lr in net.layer:
tops = []
for ind, top in enumerate(lr.top):
color = colors.setdefault(top, COLORS[len(colors) % len(COLORS)])
if top in disconnected_tops:
top = '\033[1;4m' + top
if len(lr.loss_weight) > 0:
top = '{} * {}'.format(lr.loss_weight[ind], top)
tops.append('\033[{}m{}\033[0m'.format(color, top))
top_str = ', '.join(tops)

bottoms = []
for bottom in lr.bottom:
color = colors.get(bottom, DISCONNECTED_COLOR)
bottoms.append('\033[{}m{}\033[0m'.format(color, bottom))
bottom_str = ', '.join(bottoms)

if lr.type == 'Python':
type_str = lr.python_param.module + '.' + lr.python_param.layer
else:
type_str = lr.type

# Summarize conv/pool parameters.
# TODO support rectangular/ND parameters
conv_param = lr.convolution_param
if (lr.type in ['Convolution', 'Deconvolution']
and len(conv_param.kernel_size) == 1):
arg_str = str(conv_param.kernel_size[0])
if len(conv_param.stride) > 0 and conv_param.stride[0] != 1:
arg_str += '/' + str(conv_param.stride[0])
if len(conv_param.pad) > 0 and conv_param.pad[0] != 0:
arg_str += '+' + str(conv_param.pad[0])
arg_str += ' ' + str(conv_param.num_output)
if conv_param.group != 1:
arg_str += '/' + str(conv_param.group)
elif lr.type == 'Pooling':
arg_str = str(lr.pooling_param.kernel_size)
if lr.pooling_param.stride != 1:
arg_str += '/' + str(lr.pooling_param.stride)
if lr.pooling_param.pad != 0:
arg_str += '+' + str(lr.pooling_param.pad)
else:
arg_str = ''

if len(lr.param) > 0:
param_strs = map(format_param, lr.param)
if max(map(len, param_strs)) > 0:
param_str = '({})'.format(', '.join(param_strs))
else:
param_str = ''
else:
param_str = ''

table.append([lr.name, type_str, param_str, bottom_str, '->', top_str,
arg_str])
return table

def main():
parser = argparse.ArgumentParser(description="Print a concise summary of net computation.")
parser.add_argument('filename', help='net prototxt file to summarize')
parser.add_argument('-w', '--max-width', help='maximum field width',
type=int, default=30)
args = parser.parse_args()

net = read_net(args.filename)
table = summarize_net(net)
print_table(table, max_width=args.max_width)

if __name__ == '__main__':
main()