Skip to content

Support running test with reader #11390

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 2, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 85 additions & 1 deletion python/paddle/fluid/layers/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,9 @@ def reset():
return reader


def _copy_reader_var_(block, var):
def _copy_reader_var_(block, var, newname=None):
if newname == None:
newname = var.name
new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
new_var.desc.set_shapes(var.desc.shapes())
new_var.desc.set_dtypes(var.desc.dtypes())
Expand Down Expand Up @@ -689,3 +691,85 @@ def load(out, file_path, load_as_fp16=None):
if load_as_fp16 is not None:
attrs['load_as_fp16'] = load_as_fp16
helper.append_op(type="load", inputs={}, output={"Out": out}, args=attrs)


def _is_reader_op(op, block):
if "Out" in op.output_names:
reader_out = block.vars[op.output("Out")[0]]
if reader_out.type == core.VarDesc.VarType.READER:
return True
return False


def get_test_program(filelist, program=None, startup_program=None):
"""
Transpile current train program to a program to read test dataset
if the program is using reader ops like "open_files_op".
"""
if program == None:
program = default_main_program()
if startup_program == None:
startup_program = default_startup_program()

# 1. find out the orignal reader var name
# open_files_var = None
# train_open_files_op = None
startup_reader_op_list = []

for op in startup_program.global_block().ops:
if _is_reader_op(op, startup_program.global_block()):
startup_reader_op_list.append(op)

if len(startup_reader_op_list) == 0:
return program

root_reader_op = startup_reader_op_list[0]

# 2. add operators to startup to read open and read test data files
for op in startup_reader_op_list:
orig_var_name = op.output("Out")[0]
orig_var = startup_program.global_block().vars[orig_var_name]
new_test_var = _copy_reader_var_(
startup_program.global_block(),
orig_var,
newname=orig_var_name + "_test")

# for open_files like operators have no input.
inputs = None
if "UnderlyingReader" in op.input_names:
orig_input_var_name = op.input("UnderlyingReader")[0]
orig_input_var = startup_program.global_block().vars[
orig_input_var_name]
new_input_var = _copy_reader_var_(
startup_program.global_block(),
orig_input_var,
newname=orig_input_var_name + "_test")
inputs = {"UnderlyingReader": new_input_var}
test_op = startup_program.global_block().append_op(
type=op.type,
inputs=inputs,
outputs={'Out': [new_test_var]},
attrs=op.attrs)
# root reader op's filelist attr for read test files
if op.type == root_reader_op.type:
test_op.set_attr("file_names", filelist)
if op.type == "create_multi_pass_reader":
test_op.set_attr("pass_num", 1)

# 3. rename reader vars in inference program to different name
# to avoid read from train data.
origname = root_reader_op.output("Out")[0]
newname = origname + "_test"
program.global_block().rename_var(str(origname), str(newname))
for op in program.global_block().ops:
if _is_reader_op(op, program.global_block()):
origname = op.output("Out")[0]
newname = origname + "_test"
program.global_block().rename_var(str(origname), str(newname))

if op.type == "create_multi_pass_reader":
op.set_attr("pass_num", 1)

program.sync_with_cpp()

return program