Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions python/paddle/framework/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def _build_saved_state_dict(state_dict):
raise ValueError(
"The saved tensor is not initialized. If you used group sharded, please use save_group_sharded_model."
)
save_dict[key] = np.array(value)
save_dict[key] = np.array(value.cpu())
name_table[key] = value.name
else:
save_dict[key] = value
Expand Down Expand Up @@ -91,7 +91,9 @@ def _load_state_dict_from_save_inference_model(model_path, config):
# 3. construct state_dict
load_param_dict = {}
for var_name in persistable_var_dict:
load_param_dict[var_name] = np.array(persistable_var_dict[var_name])
load_param_dict[var_name] = np.array(
persistable_var_dict[var_name].cpu()
)

# if *.info exists, we can recover structured_name
var_info_filename = str(config.params_filename) + ".info"
Expand Down Expand Up @@ -145,7 +147,7 @@ def _load_state_dict_from_save_params(model_path):
# 3. construct state_dict
load_param_dict = {}
for var in load_var_list:
load_param_dict[var.name] = np.array(var)
load_param_dict[var.name] = np.array(var.cpu())

return load_param_dict

Expand Down Expand Up @@ -290,13 +292,15 @@ def _pickle_save(obj, f, protocol):
)

def reduce_varbase(self):
data = np.array(self)
data = np.array(self.cpu())
name = self.name

return (tuple, ((name, data),))

def reduce_LoDTensor(self):
data = np.array(self)
p = core.Place()
p.set_place(paddle.CPUPlace())
data = np.array(self._copy(p))

return (eval, ('data', {'data': data}))

Expand Down Expand Up @@ -1108,7 +1112,9 @@ def load(path, **configs):
try:
tensor, _ = _load_lod_tensor(path)
if config.return_numpy:
return np.array(tensor)
p = core.Place()
p.set_place(paddle.CPUPlace())
return np.array(tensor._copy(p))
else:
if in_dygraph_mode():
return _lod_tensor2varbase(tensor)
Expand Down