@@ -63,4 +63,53 @@ pir::Value reshard(const pir::Value& x,
63
63
return reshard_op.result (0 );
64
64
}
65
65
66
+ std::vector<pir::Value> local_tensors_from_dist (
67
+ const pir::Value& input,
68
+ const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
69
+ const std::vector<int64_t >& local_dims_mapping,
70
+ const flat_hash_map<int64_t , phi::ReduceType>& local_partial_status,
71
+ const phi::distributed::ProcessMesh& global_mesh,
72
+ const std::vector<int64_t >& global_dims_mapping,
73
+ const flat_hash_map<int64_t , phi::ReduceType>& global_partial_status) {
74
+ pir::IrContext* ctx = pir::IrContext::Instance ();
75
+ std::vector<TensorDistAttribute> local_dist_attrs;
76
+ for (const phi::distributed::ProcessMesh& mesh : local_mesh_list) {
77
+ local_dist_attrs.emplace_back (TensorDistAttribute::get (
78
+ ctx, mesh, local_dims_mapping, local_partial_status));
79
+ }
80
+ TensorDistAttribute global_dist_attr = TensorDistAttribute::get (
81
+ ctx, global_mesh, global_dims_mapping, global_partial_status);
82
+
83
+ auto op = ApiBuilder::Instance ().GetBuilder ()->Build <LocalTensorsFromDistOp>(
84
+ input, local_dist_attrs, global_dist_attr);
85
+ return op.results ();
86
+ }
87
+
88
+ pir::Value dist_tensor_from_locals (
89
+ const std::vector<pir::Value>& inputs,
90
+ const std::vector<phi::distributed::ProcessMesh>& local_mesh_list,
91
+ const std::vector<int64_t >& local_dims_mapping,
92
+ const flat_hash_map<int64_t , phi::ReduceType>& local_partial_status,
93
+ const phi::distributed::ProcessMesh& global_mesh,
94
+ const std::vector<int64_t >& global_dims_mapping,
95
+ const flat_hash_map<int64_t , phi::ReduceType>& global_partial_status,
96
+ const std::vector<int64_t >& global_shape) {
97
+ pir::IrContext* ctx = pir::IrContext::Instance ();
98
+
99
+ std::vector<TensorDistAttribute> local_dist_attrs;
100
+ for (const phi::distributed::ProcessMesh& mesh : local_mesh_list) {
101
+ local_dist_attrs.emplace_back (TensorDistAttribute::get (
102
+ ctx, mesh, local_dims_mapping, local_partial_status));
103
+ }
104
+
105
+ TensorDistAttribute global_dist_attr = TensorDistAttribute::get (
106
+ ctx, global_mesh, global_dims_mapping, global_partial_status);
107
+
108
+ phi::DDim global_ddim = phi::make_ddim (global_shape);
109
+
110
+ auto op = ApiBuilder::Instance ().GetBuilder ()->Build <DistTensorFromLocalsOp>(
111
+ inputs, local_dist_attrs, global_dist_attr, global_ddim);
112
+ return op.result (0 );
113
+ }
114
+
66
115
} // namespace paddle::dialect
0 commit comments