Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【auto parallel】剔除切分推导相关的头文件对proto 的依赖 #60543

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 7 additions & 3 deletions paddle/fluid/distributed/auto_parallel/dist_attr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"

namespace paddle {
namespace distributed {
Expand Down Expand Up @@ -406,14 +407,17 @@ OperatorDistAttrProto OperatorDistAttr::to_proto() const {
for (const auto& item : input_dist_attrs_) {
auto proto_item = proto.mutable_input_dist_attrs()->Add();
proto_item->set_name(item.first);
proto_item->mutable_tensor_dist_attr()->CopyFrom(item.second.to_proto());
proto_item->mutable_tensor_dist_attr()->CopyFrom(
phi::distributed::to_proto(item.second));
}
for (const auto& item : output_dist_attrs_) {
auto proto_item = proto.mutable_output_dist_attrs()->Add();
proto_item->set_name(item.first);
proto_item->mutable_tensor_dist_attr()->CopyFrom(item.second.to_proto());
proto_item->mutable_tensor_dist_attr()->CopyFrom(
phi::distributed::to_proto(item.second));
}
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
proto.mutable_process_mesh()->CopyFrom(
phi::distributed::to_proto(process_mesh_));
proto.set_impl_type(impl_type_);
proto.set_impl_idx(impl_idx_);
proto.set_chunk_id(chunk_id_);
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/core/distributed/auto_parallel/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ collect_srcs(
dist_mapper.cc
dist_tensor.cc
dist_meta_tensor.cc
proto_helper.cc
placement_types.cc
inferspmd_utils.cc)

Expand Down
70 changes: 31 additions & 39 deletions paddle/phi/core/distributed/auto_parallel/device_mesh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@ limitations under the License. */
#include <iterator>

#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace phi {
namespace distributed {
namespace auto_parallel {
Expand All @@ -41,13 +41,11 @@ DeviceCapability DeviceCapability::from_proto(
return capability;
}

DeviceCapabilityProto DeviceCapability::to_proto() const {
DeviceCapabilityProto proto;
proto.set_single_precision_flops(single_precision_flops);
proto.set_double_precision_flops(double_precision_flops);
proto.set_memory_size_in_bytes(memory_size_in_bytes);
proto.set_clock_rate_in_ghz(clock_rate_in_ghz);
return proto;
// Serialize this device capability into the caller-provided proto message.
void DeviceCapability::to_proto(DeviceCapabilityProto *proto) const {
  auto &out = *proto;
  out.set_single_precision_flops(single_precision_flops);
  out.set_double_precision_flops(double_precision_flops);
  out.set_memory_size_in_bytes(memory_size_in_bytes);
  out.set_clock_rate_in_ghz(clock_rate_in_ghz);
}

std::string Device::to_string() const {
Expand All @@ -69,14 +67,13 @@ Device Device::from_proto(const DeviceProto &proto) {
return device;
}

DeviceProto Device::to_proto() const {
DeviceProto proto;
proto.set_global_id(global_id_);
proto.set_local_id(local_id_);
proto.set_machine_id(machine_id_);
proto.set_type(type_);
proto.mutable_capability()->CopyFrom(capability_.to_proto());
return proto;
// Fill the caller-provided DeviceProto with this device's id/type fields
// and its serialized capability.
void Device::to_proto(DeviceProto *proto) const {
  auto &msg = *proto;
  msg.set_global_id(global_id_);
  msg.set_local_id(local_id_);
  msg.set_machine_id(machine_id_);
  msg.set_type(type_);
  // The capability is converted through the free-function helper so this
  // type does not need a member returning a proto by value.
  *msg.mutable_capability() = phi::distributed::to_proto(capability_);
}

bool operator==(const Device &lhs, const Device &rhs) {
Expand Down Expand Up @@ -109,11 +106,9 @@ LinkCapability LinkCapability::from_proto(const LinkCapabilityProto &proto) {
return capability;
}

LinkCapabilityProto LinkCapability::to_proto() const {
LinkCapabilityProto proto;
proto.set_bandwidth(bandwidth);
proto.set_latency(latency);
return proto;
void LinkCapability::to_proto(LinkCapabilityProto *proto) const {
proto->set_bandwidth(bandwidth);
proto->set_latency(latency);
}

std::string Link::to_string() const {
Expand All @@ -133,13 +128,12 @@ Link Link::from_proto(const LinkProto &proto) {
return link;
}

LinkProto Link::to_proto() const {
LinkProto proto;
proto.set_source_id(source_id_);
proto.set_target_id(target_id_);
proto.set_type(type_);
proto.mutable_capability()->CopyFrom(capability_.to_proto());
return proto;
void Link::to_proto(LinkProto *proto) const {
proto->set_source_id(source_id_);
proto->set_target_id(target_id_);
proto->set_type(type_);
proto->mutable_capability()->CopyFrom(
phi::distributed::to_proto(capability_));
}

bool operator==(const Link &lhs, const Link &rhs) {
Expand Down Expand Up @@ -355,34 +349,32 @@ DeviceMesh DeviceMesh::from_proto(const DeviceMeshProto &proto) {
return mesh;
}

DeviceMeshProto DeviceMesh::to_proto() const {
DeviceMeshProto proto;

proto.set_name(name_);
void DeviceMesh::to_proto(DeviceMeshProto *proto) const {
proto->set_name(name_);

for (const auto &i : shape_) {
proto.add_shape(i);
proto->add_shape(i);
}

for (const auto &i : device_ids_) {
proto.add_device_ids(i);
proto->add_device_ids(i);
}

for (const auto &i : dim_names_) {
proto.add_dim_names(i);
proto->add_dim_names(i);
}

for (const auto &device : devices_) {
proto.mutable_devices()->Add()->CopyFrom(device.second.to_proto());
proto->mutable_devices()->Add()->CopyFrom(
phi::distributed::to_proto(device.second));
}

for (const auto &neighbors : links_) {
for (const auto &link : neighbors.second) {
proto.mutable_links()->Add()->CopyFrom(link.second.to_proto());
proto->mutable_links()->Add()->CopyFrom(
phi::distributed::to_proto(link.second));
}
}

return proto;
}

bool operator==(const DeviceMesh &lhs, const DeviceMesh &rhs) {
Expand Down
17 changes: 12 additions & 5 deletions paddle/phi/core/distributed/auto_parallel/device_mesh.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,13 @@ limitations under the License. */
namespace phi {
namespace distributed {
namespace auto_parallel {

class DeviceCapabilityProto;
class DeviceProto;
class LinkCapabilityProto;
class LinkProto;
class DeviceMeshProto;

struct DeviceCapability {
double single_precision_flops = 0.0;
double double_precision_flops = 0.0;
Expand All @@ -40,7 +47,7 @@ struct DeviceCapability {
std::string to_string() const;

static DeviceCapability from_proto(const DeviceCapabilityProto& proto);
DeviceCapabilityProto to_proto() const;
void to_proto(DeviceCapabilityProto* proto) const;
};

inline std::ostream& operator<<(std::ostream& os, const DeviceCapability& obj) {
Expand Down Expand Up @@ -74,7 +81,7 @@ class Device {
std::string to_string() const;

static Device from_proto(const DeviceProto& proto);
DeviceProto to_proto() const;
void to_proto(DeviceProto* proto) const;

private:
int64_t global_id_;
Expand Down Expand Up @@ -103,7 +110,7 @@ struct LinkCapability {
std::string to_string() const;

static LinkCapability from_proto(const LinkCapabilityProto& proto);
LinkCapabilityProto to_proto() const;
void to_proto(LinkCapabilityProto* proto) const;
};

inline std::ostream& operator<<(std::ostream& os, const LinkCapability& obj) {
Expand Down Expand Up @@ -131,7 +138,7 @@ class Link {
std::string to_string() const;

static Link from_proto(const LinkProto& proto);
LinkProto to_proto() const;
void to_proto(LinkProto* proto) const;

private:
int64_t source_id_;
Expand Down Expand Up @@ -273,7 +280,7 @@ class DeviceMesh {
std::string to_string() const;

static DeviceMesh from_proto(const DeviceMeshProto& proto);
DeviceMeshProto to_proto() const;
void to_proto(DeviceMeshProto* proto) const;

private:
std::string name_;
Expand Down
20 changes: 10 additions & 10 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License. */
#include <iterator>

#include "glog/logging.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"

namespace phi {
namespace distributed {
Expand Down Expand Up @@ -308,25 +309,24 @@ void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
}
}

TensorDistAttrProto TensorDistAttr::to_proto() const {
TensorDistAttrProto proto;
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
void TensorDistAttr::to_proto(TensorDistAttrProto* proto) const {
proto->mutable_process_mesh()->CopyFrom(
phi::distributed::to_proto(process_mesh_));
for (const auto& i : dims_mapping_) {
proto.add_dims_mapping(i);
proto->add_dims_mapping(i);
}
proto.set_batch_dim(batch_dim_);
proto.set_chunk_id(chunk_id_);
proto->set_batch_dim(batch_dim_);
proto->set_chunk_id(chunk_id_);
for (const auto& i : dynamic_dims_) {
proto.add_dynamic_dims(i);
proto->add_dynamic_dims(i);
}
return proto;
}

std::string TensorDistAttr::serialize_to_string() {
std::string data;
auto proto = to_proto();
auto proto = phi::distributed::to_proto(*this);
proto.SerializeToString(&data);
PADDLE_ENFORCE_EQ(to_proto().SerializeToString(&data),
PADDLE_ENFORCE_EQ(phi::distributed::to_proto(*this).SerializeToString(&data),
true,
errors::InvalidArgument(
"Failed to serialize tensor dist attr to string."));
Expand Down
7 changes: 5 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ limitations under the License. */
#include <vector>

#include "paddle/phi/common/reduce_type.h"
#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"
#include "paddle/phi/core/enforce.h"
Expand All @@ -32,6 +31,10 @@ limitations under the License. */
namespace phi {
namespace distributed {

namespace auto_parallel {
class TensorDistAttrProto;
}

constexpr int kReplicateDim = -1;

class PlacementStatus {
Expand Down Expand Up @@ -169,7 +172,7 @@ class TEST_API TensorDistAttr {
// future partial-support-stage-II.
void from_proto(const auto_parallel::TensorDistAttrProto& proto);

auto_parallel::TensorDistAttrProto to_proto() const;
void to_proto(auto_parallel::TensorDistAttrProto* proto) const;

std::string serialize_to_string();

Expand Down
10 changes: 5 additions & 5 deletions paddle/phi/core/distributed/auto_parallel/dist_mapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#include <algorithm>

#include "paddle/phi/core/distributed/auto_parallel/dist_mapper.h"
#include "paddle/phi/core/distributed/auto_parallel/proto_helper.h"
#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace phi {
Expand Down Expand Up @@ -91,20 +92,19 @@ DistributedMapper DistributedMapper::from_proto(
return dist_mapper;
}

DistributedMapperProto DistributedMapper::to_proto() const {
DistributedMapperProto proto;
void DistributedMapper::to_proto(DistributedMapperProto* proto) const {
for (const auto& item : device_meshes_) {
proto.mutable_device_meshes()->Add()->CopyFrom(item.second.to_proto());
proto->mutable_device_meshes()->Add()->CopyFrom(
phi::distributed::to_proto(item.second));
}
for (const auto& outer : process_id_to_device_ids_) {
auto proto_item = proto.mutable_process_id_to_device_ids()->Add();
auto proto_item = proto->mutable_process_id_to_device_ids()->Add();
proto_item->set_process_id(outer.first);
proto_item->set_device_mesh_name(outer.second.first);
for (const auto& inner : outer.second.second) {
proto_item->add_device_ids(inner);
}
}
return proto;
}

std::string DistributedMapper::to_string() const {
Expand Down
5 changes: 3 additions & 2 deletions paddle/phi/core/distributed/auto_parallel/dist_mapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,15 @@ limitations under the License. */

#include <utility>

#include "paddle/phi/core/distributed/auto_parallel/auto_parallel.pb.h"
#include "paddle/phi/core/distributed/auto_parallel/device_mesh.h"
#include "paddle/phi/core/distributed/auto_parallel/process_mesh.h"

namespace phi {
namespace distributed {
namespace auto_parallel {

class DistributedMapperProto;

class DistributedMapper {
public:
DistributedMapper() = default;
Expand Down Expand Up @@ -52,7 +53,7 @@ class DistributedMapper {
std::string to_string() const;

static DistributedMapper from_proto(const DistributedMapperProto& proto);
DistributedMapperProto to_proto() const;
void to_proto(DistributedMapperProto* proto) const;

private:
std::map<std::string, DeviceMesh> device_meshes_;
Expand Down
13 changes: 4 additions & 9 deletions paddle/phi/core/distributed/auto_parallel/process_mesh.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ limitations under the License. */

#include <algorithm>
#include <iterator>

#include "paddle/phi/core/distributed/auto_parallel/utils.h"

namespace phi {
Expand Down Expand Up @@ -105,22 +104,18 @@ ProcessMesh ProcessMesh::from_proto(const ProcessMeshProto &proto) {
return mesh;
}

ProcessMeshProto ProcessMesh::to_proto() const {
ProcessMeshProto proto;

void ProcessMesh::to_proto(ProcessMeshProto *proto) const {
for (const auto &i : shape_) {
proto.add_shape(i);
proto->add_shape(i);
}

for (const auto &i : process_ids_) {
proto.add_process_ids(i);
proto->add_process_ids(i);
}

for (const auto &i : dim_names_) {
proto.add_dim_names(i);
proto->add_dim_names(i);
}

return proto;
}

bool operator==(const ProcessMesh &lhs, const ProcessMesh &rhs) {
Expand Down
Loading