Skip to content

Commit

Permalink
fix ut
Browse files Browse the repository at this point in the history
  • Loading branch information
zoooo0820 committed Dec 29, 2023
1 parent c751c82 commit f15e3e5
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/slice_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -609,9 +609,9 @@ static paddle::Tensor dealWithValues(const paddle::Tensor& tensor,
tensor.dtype() == phi::DataType::INT16 ||
tensor.dtype() == phi::DataType::INT8 ||
tensor.dtype() == phi::DataType::UINT8) {
values->push_back(value_obj_tmp.cast<int32_t>());
values->push_back(value_obj_tmp.cast<float>());
} else if (tensor.dtype() == phi::DataType::INT64) {
values->push_back(value_obj_tmp.cast<int64_t>());
values->push_back(value_obj_tmp.cast<double>());
} else if (tensor.dtype() == phi::DataType::BOOL) {
values->push_back(value_obj_tmp.cast<bool>());
} else if (tensor.dtype() == phi::DataType::COMPLEX64) {
Expand Down
8 changes: 4 additions & 4 deletions python/paddle/base/dygraph/tensor_patch_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -975,7 +975,7 @@ def __array__(self, dtype=None):
array = array.astype(dtype)
return array

def pre_deal_index(self, item, value=None):
def pre_deal_index(self, item):
# since in pybind there is no effiency way to transfer Py_Tuple/Py_List/Py_Range to Tensor
# we call this function in python level.
item = list(item) if isinstance(item, tuple) else [item]
Expand All @@ -985,14 +985,14 @@ def pre_deal_index(self, item, value=None):
elif isinstance(slice_item, range):
item[i] = paddle.to_tensor(list(slice_item))

return tuple(item), value
return tuple(item)

def __getitem__(self, item):
item, _ = pre_deal_index(self, item)
item = pre_deal_index(self, item)
return self._getitem_dygraph(item)

def __setitem__(self, item, value):
item, value = pre_deal_index(self, item, value)
item = pre_deal_index(self, item)
return self._setitem_dygraph(item, value)

@framework.dygraph_only
Expand Down

0 comments on commit f15e3e5

Please sign in to comment.