4 changes: 2 additions & 2 deletions paddle/fluid/eager/grad_tensor_holder.cc
@@ -89,11 +89,11 @@ void GradTensorHolder::CopyValueFromTensor(size_t slot_id,
auto init_grad =
paddle::experimental::full(t.shape(), 1, t.dtype(), t.place());
auto global_dense_t =
static_cast<phi::DenseTensor*>(init_grad.impl().get());
std::static_pointer_cast<phi::DenseTensor>(init_grad.impl());
auto dist_t =
static_cast<phi::distributed::DistTensor*>(t.impl().get());
init_grad.set_impl(std::make_shared<phi::distributed::DistTensor>(
*global_dense_t, dist_t->dist_attr()));
global_dense_t, dist_t->dist_attr()));
buffer_[slot_id][rank] = init_grad;
} else {
PADDLE_THROW(paddle::platform::errors::Fatal(
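The change above is the first instance of the pattern this PR applies everywhere: the gradient's local DenseTensor is no longer copied into the DistTensor, it is shared through a `shared_ptr`. A runnable toy analogue of that ownership change (illustrative types only, not Paddle code):

```cpp
#include <iostream>
#include <memory>
#include <vector>

// Toy stand-ins for phi::DenseTensor and phi::distributed::DistTensor.
struct Dense {
  std::vector<float> data;
};

// Old style: the DistTensor held its own copy of the DenseTensor.
struct DistByCopy {
  Dense value;
  explicit DistByCopy(const Dense& global) : value(global) {}
};

// New style: the DistTensor shares the DenseTensor with the caller.
struct DistShared {
  std::shared_ptr<Dense> value;
  explicit DistShared(std::shared_ptr<Dense> global) : value(std::move(global)) {}
};

int main() {
  auto dense = std::make_shared<Dense>(Dense{{1.f, 2.f, 3.f}});
  DistByCopy a(*dense);  // deep copy: two buffers alive
  DistShared b(dense);   // shared: one buffer, use_count grows
  dense->data[0] = 42.f;
  std::cout << a.value.data[0] << " vs " << b.value->data[0] << "\n";  // 1 vs 42
  std::cout << "use_count = " << dense.use_count() << "\n";            // 2
  return 0;
}
```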
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/eager.cc
@@ -189,8 +189,8 @@ void CreateDistTensorWithNumpyValue(TensorObject* self,
"CPUPlace/XPUPlace/CUDAPlace/CUDAPinnedPlace/CustomPlace"));
}

auto dist_tensor =
std::make_shared<phi::distributed::DistTensor>(dense_tensor, dist_attr);
auto dist_tensor = std::make_shared<phi::distributed::DistTensor>(
std::make_shared<phi::DenseTensor>(dense_tensor), dist_attr);
self->tensor.set_impl(dist_tensor);

if (!autograd_meta->GetMutableGradNode()) {
@@ -280,13 +280,13 @@ void InitDistTensorWithTensor(TensorObject* self,
if (place == src.place()) {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(src.impl());
self->tensor.set_impl(std::make_shared<DistTensor>(*tensor, dist_attr));
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
VLOG(4) << "Same place, do ShareDataWith for DistTensor.";
} else {
std::shared_ptr<phi::DenseTensor> tensor =
std::static_pointer_cast<phi::DenseTensor>(
src.copy_to(place, true).impl());
self->tensor.set_impl(std::make_shared<DistTensor>(*tensor, dist_attr));
self->tensor.set_impl(std::make_shared<DistTensor>(tensor, dist_attr));
VLOG(4) << "Different place, do TensorCopy for DistTensor.";
}
if (src.get_autograd_meta()) {
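Both call sites above move to the shared_ptr constructor, but with different ownership flavors. A condensed sketch (Paddle-internal code, setup elided, variable names taken from this diff):

```cpp
// CreateDistTensorWithNumpyValue: the freshly filled DenseTensor value is
// wrapped once into a shared_ptr before it is handed to the DistTensor.
auto dist_tensor = std::make_shared<phi::distributed::DistTensor>(
    std::make_shared<phi::DenseTensor>(dense_tensor), dist_attr);

// InitDistTensorWithTensor: the source tensor's impl is already a
// shared_ptr<phi::DenseTensor>, so it is shared directly with no copy,
// matching the "Same place, do ShareDataWith" log message.
auto dense = std::static_pointer_cast<phi::DenseTensor>(src.impl());
self->tensor.set_impl(std::make_shared<DistTensor>(dense, dist_attr));
```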
4 changes: 2 additions & 2 deletions paddle/fluid/pybind/eager_utils.cc
@@ -2079,9 +2079,9 @@ void DistTensorConverter::convert(Tensor* x) {
phi::distributed::TensorDistAttr dist_attr(
phi::vectorize(x->impl()->dims()));
dist_attr.set_process_mesh(*mesh);
auto dense_t = static_cast<phi::DenseTensor*>(x->impl().get());
auto dense_t = std::static_pointer_cast<phi::DenseTensor>(x->impl());
x->set_impl(
std::make_shared<phi::distributed::DistTensor>(*dense_t, dist_attr));
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr));
}
}

12 changes: 2 additions & 10 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -566,11 +566,7 @@ std::vector<phi::distributed::DistTensor*> SetKernelDistOutput(
if (tmp) {
// TODO(GhostScreaming): now all dist case are nullptr
if (tmp->impl() == nullptr) {
phi::DenseTensor dense_t;
// TODO(GhostScreaming): polish code, dist_attr is null now
phi::distributed::TensorDistAttr dist_attr;
auto dist_t =
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
tmp->set_impl(dist_t);
}
result.emplace_back(
@@ -587,11 +583,7 @@
out->reserve(out_size);
std::vector<phi::distributed::DistTensor*> results(out_size);
for (size_t i = 0; i < out_size; ++i) {
phi::DenseTensor dense_t;
// TODO(GhostScreaming): polish code, dist_attr is null now
phi::distributed::TensorDistAttr dist_attr;
auto dist_t =
std::make_shared<phi::distributed::DistTensor>(dense_t, dist_attr);
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
results[i] = dist_t.get();
out->emplace_back();
out->back().set_impl(dist_t);
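The repeated empty-DenseTensor-plus-empty-dist_attr boilerplate collapses into the new default constructor. A minimal sketch of the resulting output-allocation step (Paddle-internal; assumes the surrounding loop shown in the diff):

```cpp
// The default constructor now allocates an empty shared DenseTensor value,
// so placeholder outputs no longer need a hand-built dense tensor/dist_attr.
auto dist_t = std::make_shared<phi::distributed::DistTensor>();
tmp->set_impl(dist_t);  // dist_attr and data are filled in by the kernel later
```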
3 changes: 2 additions & 1 deletion paddle/phi/api/lib/data_transform.cc
@@ -785,7 +785,8 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
// change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
out.push_back(std::make_shared<phi::distributed::DistTensor>(
trans_in_tensor, dist_tensor->dist_attr()));
std::make_shared<phi::DenseTensor>(trans_in_tensor),
dist_tensor->dist_attr()));
}
} else {
out.push_back(nullptr);
73 changes: 43 additions & 30 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.cc
@@ -33,35 +33,45 @@ inline void check_defined(const DistTensor& dist_tensor,
method_hint));
}

DistTensor::DistTensor(const phi::DenseTensor& global_value,
DistTensor::DistTensor() : value_(std::make_shared<DenseTensor>()) {}

DistTensor::DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
Contributor review comment: When initializing a DistTensor from a DenseTensor, must the given DenseTensor be a global one?

const TensorDistAttr& dist_attr)
: dims_(global_value.dims()), dist_attr_(dist_attr), value_(global_value) {
// TODO(liyurui): This is a temporary solution. We need to support only infer
// meta when the input dense_tensor is empty.
// Support the value in DistTensor only has DenseTensor meta
// but without actual data. So we can visit its meta attr even if it is
// undefined.
: dims_(global_value->dims()),
dist_attr_(dist_attr),
value_(std::make_shared<DenseTensor>()) {
// If the current rank isn't in the process_mesh, we should create an
// uninitialized tensor only with tensor_meta.
if (IsCurRankInMesh(dist_attr.process_mesh())) {
if (value_.initialized() && !dist_attr.is_replicated()) {
if (!dist_attr.is_replicated()) {
// 1. create replicated global tensor
int64_t dims_size = global_value.dims().size();
std::vector<int64_t> dims_mapping(dims_size, -1);
dist_attr_.set_dims_mapping(dims_mapping);
if (dist_attr_.is_partial()) {
dist_attr_.clean_partial_status();
}
dist_attr_.set_dims_mapping(dims_mapping);
TensorDistAttr replicated_dist_attr(vectorize(global_value->dims()));
replicated_dist_attr.set_process_mesh(dist_attr.process_mesh());
DistTensor replicated_tensor(global_value, replicated_dist_attr);

// 2. reshard from replicated to other state
auto* func = ChooseProperReshardFunction(*this, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value.place());
func->Eval(dev_ctx, *this, dist_attr, this);
auto* func = ChooseProperReshardFunction(replicated_tensor, dist_attr);
auto* dev_ctx = DeviceContextPool::Instance().Get(global_value->place());
func->Eval(dev_ctx, replicated_tensor, dist_attr, this);
} else {
value_ = global_value;
}
} else {
// TODO(liyurui): The following logic is illegal and should be removed
// later. It exists temporarily because the basic execution procedure is not
// ready; sometimes we even try to construct a DistTensor with an empty
// DistAttr. We warn here when the DistAttr is empty, for debugging use.
if (dist_attr.empty()) {
LOG(WARNING) << "Try to construct a dist tensor with empty dist attr.";
}
value_ = global_value;
}
}

DistTensor::DistTensor(const DDim& dims, const TensorDistAttr& dist_attr)
: dims_(dims), dist_attr_(dist_attr) {}
: dims_(dims),
dist_attr_(dist_attr),
value_(std::make_shared<DenseTensor>()) {}

void DistTensor::unsafe_set_dims(const DDim& dims) {
if (this->initialized()) {
@@ -80,39 +90,42 @@ void DistTensor::unsafe_set_dist_attr(const TensorDistAttr& dist_attr) {
}

int64_t DistTensor::numel() const {
check_defined(*this, "numel");
return value_.numel();
// DistTensor with uninitialized local tensor can
// also have numel.
return product(dims_);
}

const DDim& DistTensor::local_dims() const {
check_defined(*this, "local_dims");
return value_.dims();
return value_->dims();
}

bool DistTensor::valid() const {
check_defined(*this, "valid");
return value_.valid();
return value_->valid();
}

bool DistTensor::defined() const { return value_.holder_ != nullptr; }
bool DistTensor::defined() const { return value_->holder_ != nullptr; }

bool DistTensor::initialized() const {
return value_.holder_ != nullptr && value_.holder_->ptr();
return value_->holder_ != nullptr && value_->holder_->ptr();
}

DataType DistTensor::dtype() const {
check_defined(*this, "dtype");
return value_.dtype();
// DistTensor with uninitialized local tensor can
// also have dtype.
return value_->dtype();
}

DataLayout DistTensor::layout() const {
check_defined(*this, "layout");
return value_.layout();
// DistTensor with uninitialized local tensor can
// also have layout.
return value_->layout();
}

const Place& DistTensor::place() const {
check_defined(*this, "place");
return value_.holder_->place();
return value_->holder_->place();
}

void* DistTensor::AllocateFrom(Allocator* allocator,
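With `value_` now always pointing at a (possibly empty) DenseTensor, meta queries such as `numel()` are answered from the global dims rather than from the local allocation. A runnable toy analogue of the new `numel()` (plain C++, not Paddle code):

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Toy analogue of the revised DistTensor::numel(): the element count is
// derived from the *global* dims alone, so it is available even before the
// local DenseTensor holds an allocation (e.g. on ranks outside the mesh).
int64_t numel_from_dims(const std::vector<int64_t>& global_dims) {
  return std::accumulate(global_dims.begin(), global_dims.end(),
                         static_cast<int64_t>(1), std::multiplies<int64_t>());
}

int main() {
  std::cout << numel_from_dims({4, 8, 2}) << "\n";  // prints 64, no data needed
  return 0;
}
```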
10 changes: 5 additions & 5 deletions paddle/phi/core/distributed/auto_parallel/dist_tensor.h
@@ -30,12 +30,12 @@ class DistTensor final
/// \brief Be careful when creating a dist tensor with the default constructor.
/// This should only be used in reshard for now; the dist properties
/// will be set by reshard later.
DistTensor() = default;
DistTensor();

/// \brief Construct a dist tensor based on a dense tensor.
/// \param global_value The global dense tensor of the current tensor.
/// \param dist_attr The distributed attributes of the current tensor.
DistTensor(const phi::DenseTensor& global_value,
DistTensor(const std::shared_ptr<phi::DenseTensor>& global_value,
const TensorDistAttr& dist_attr);

/// \brief Construct an empty dist tensor (for infer spmd)
@@ -68,7 +68,7 @@ class DistTensor final

/// \brief Returns the dense tensor value's const reference in dist tensor.
/// \return The DenseTensor value's const reference
const DenseTensor& value() const { return value_; }
const DenseTensor& value() const { return *value_; }

/// \brief Returns the mutable dense tensor value in dist tensor.
/// \note If DenseTensor value is modified externally, the corresponding
@@ -77,7 +77,7 @@
/// so you need to make sure to consider it thoroughly when using
/// this method.
/// \return The mutable pointer of DenseTensor value
DenseTensor* unsafe_mutable_value() { return &value_; }
DenseTensor* unsafe_mutable_value() { return value_.get(); }

/// \brief Returns the global dims of the dist tensor.
/// \return The global dims of the dist tensor.
@@ -126,7 +126,7 @@ class DistTensor final
// The distributed attributes
TensorDistAttr dist_attr_;
// The local DenseTensor value
DenseTensor value_;
std::shared_ptr<DenseTensor> value_;
};

} // namespace distributed
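A condensed usage sketch of the new interface, mirroring the updated unit test at the bottom of this diff (compiles only inside the Paddle tree; the `alloc`, `meta`, and `dist_attr` setup is elided and assumed to match that test):

```cpp
// Build the DenseTensor as a shared_ptr and hand it over; the DistTensor
// now shares it instead of copying it.
std::shared_ptr<phi::DenseTensor> x =
    std::make_shared<phi::DenseTensor>(alloc, meta);
phi::distributed::DistTensor dist_x(x, dist_attr);

const phi::DenseTensor& local = dist_x.value();                   // returns *value_
phi::DenseTensor* mutable_local = dist_x.unsafe_mutable_value();  // returns value_.get()
```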
@@ -41,7 +41,7 @@ ProcessMesh GetSubProcessMesh(const ProcessMesh& mesh, int64_t axis) {
for (int64_t j = coord.size() - 2; j >= 0; --j) {
rank += coord[j] * mesh.dim_size(j + 1);
}
process_ids.emplace_back(rank);
process_ids.emplace_back(mesh.process_ids()[rank]);
}

ProcessMesh out_mesh(shape, process_ids, dim_names);
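The one-line fix above matters because the loop's `rank` is a linear index into the mesh coordinates, while a ProcessMesh can be built from arbitrary global process ids; the sub-mesh therefore has to carry `mesh.process_ids()[rank]`, not the index itself. A runnable toy illustration (plain C++, no Paddle types):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // A 2x2 mesh whose member processes happen to be global ranks {2, 4, 6, 8}.
  const std::vector<int64_t> process_ids = {2, 4, 6, 8};
  const int64_t linear_index = 3;  // coordinate (1, 1) flattened

  std::cout << "old behaviour (index):     " << linear_index << "\n";             // 3
  std::cout << "fixed behaviour (real id): " << process_ids[linear_index] << "\n";  // 8
  return 0;
}
```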
4 changes: 2 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/reshard_function.cc
@@ -30,7 +30,7 @@ std::shared_ptr<DistTensor> ReshardFunction::Eval(
}

void ReshardFunction::SetValue(DistTensor* tensor, const DenseTensor& value) {
tensor->value_ = value;
tensor->value_ = std::make_shared<DenseTensor>(value);
}

void ReshardFunction::SetDistProps(DistTensor* tensor,
@@ -56,7 +56,7 @@ void ReshardFunction::SetDistProps(DistTensor* tensor,
}

DenseTensor* ReshardFunction::GetMutableTensor(DistTensor* tensor) {
return &tensor->value_;
return tensor->value_.get();
}

ReshardFunction* ChooseProperReshardFunction(
@@ -89,8 +89,8 @@ void SameStatusReshardFunction::Eval(phi::DeviceContext* dev_ctx,
for (const auto& iter : p2p_pair) {
int64_t src = iter.first;
int64_t dst = iter.second;
VLOG(3) << "Send/Recv from src " << src << " to dst " << dst;
if (src == cur_global_rank) {
VLOG(3) << "Send from src " << src << " to dst " << dst;
int64_t dst_local_rank = GetLocalRankInParticipate(all_process_ids, dst);
// Since the send kernel only has inputs (no outputs), we don't actually
// need infermeta here. For that reason, just use the kernel directly.
@@ -102,6 +102,7 @@
dst_local_rank,
dynamic_shape);
} else if (dst == cur_global_rank) {
VLOG(3) << "Recv from src " << src << " to dst " << dst;
int64_t src_local_rank = GetLocalRankInParticipate(all_process_ids, src);
RESHARD_FUNCTOR_WITH_COMM(dev_ctx,
PRecv,
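The logging change is easier to see with the loop isolated: every rank walks the same (src, dst) pairs, sends when it is the source, receives when it is the destination, and now logs only the branch it actually takes. A runnable toy version (plain C++, no Paddle dependencies; the communication calls are omitted):

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

int main() {
  const int64_t cur_global_rank = 1;  // pretend we are global rank 1
  const std::vector<std::pair<int64_t, int64_t>> p2p_pair = {{0, 2}, {1, 3}};
  for (const auto& iter : p2p_pair) {
    const int64_t src = iter.first;
    const int64_t dst = iter.second;
    if (src == cur_global_rank) {
      std::cout << "Send from src " << src << " to dst " << dst << "\n";
    } else if (dst == cur_global_rank) {
      std::cout << "Recv from src " << src << " to dst " << dst << "\n";
    }
  }
  return 0;  // rank 1 prints only: Send from src 1 to dst 3
}
```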
2 changes: 1 addition & 1 deletion test/cpp/auto_parallel/dist_tensor_test.cc
@@ -43,7 +43,7 @@ TEST(dist_tensor, constructor) {
dist_attr.set_process_mesh(mesh);

// copy construct
DenseTensor x1(alloc, meta);
std::shared_ptr<DenseTensor> x1 = std::make_shared<DenseTensor>(alloc, meta);
DistTensor dist_x1(x1, dist_attr);
EXPECT_TRUE(dist_x1.defined());
EXPECT_TRUE(dist_x1.initialized());