File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -111,7 +111,8 @@ def restore_ckpt(model,
111
111
if tf .io .gfile .isdir (ckpt_path_or_file ):
112
112
ckpt_path_or_file = tf .train .latest_checkpoint (ckpt_path_or_file )
113
113
114
- var_shape_map = tf .train .load_checkpoint (ckpt_path_or_file ).get_variable_to_shape_map ()
114
+ reader = tf .train .load_checkpoint (ckpt_path_or_file )
115
+ var_shape_map = reader .get_variable_to_shape_map ()
115
116
if '_CHECKPOINTABLE_OBJECT_GRAPH' in var_shape_map :
116
117
model .load_weights (ckpt_path_or_file )
117
118
else :
@@ -141,7 +142,7 @@ def restore_ckpt(model,
141
142
else :
142
143
raise ValueError (msg )
143
144
else :
144
- var .assign (tf . train . load_variable ( ckpt_path_or_file , key ))
145
+ var .assign (reader . get_tensor ( key ), read_value = False )
145
146
if i < 10 :
146
147
logging .info ('Init %s from %s (%s)' , var .name , key , ckpt_path_or_file )
147
148
else :
You can’t perform that action at this time.
0 commit comments