89
89
90
90
// shardy
91
91
#include " shardy/dialect/sdy/ir/dialect.h"
92
+ #include " shardy/dialect/sdy/transforms/passes.h"
92
93
#include " shardy/integrations/c/attributes.h"
93
94
#include " xla/pjrt/mlir_to_hlo.h"
94
95
#include " xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h"
96
+ #include " xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_export.h"
95
97
#include " xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"
96
98
97
99
// IFRT
@@ -780,7 +782,8 @@ extern "C" void FutureAwait(FutureType *Future) { Future->Await(); }
780
782
xla::CompileOptions GenerateCompileOptions (int64_t device_id, bool is_sharded,
781
783
const int64_t *mesh_ids,
782
784
int64_t num_mesh_ids,
783
- const char *xla_gpu_cuda_data_dir) {
785
+ const char *xla_gpu_cuda_data_dir,
786
+ bool use_shardy_partitioner) {
784
787
xla::CompileOptions options;
785
788
options.executable_build_options .mutable_debug_options ()
786
789
->set_xla_gpu_cuda_data_dir (xla_gpu_cuda_data_dir);
@@ -792,7 +795,8 @@ xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
792
795
options.executable_build_options .set_num_partitions (num_mesh_ids);
793
796
794
797
options.executable_build_options .set_use_spmd_partitioning (true );
795
- options.executable_build_options .set_use_shardy_partitioner (true );
798
+ options.executable_build_options .set_use_shardy_partitioner (
799
+ use_shardy_partitioner);
796
800
797
801
// Auto-partitioning for GPUs is not available in the open-source version of XLA.
798
802
// options.executable_build_options.set_use_auto_spmd_partitioning(true);
@@ -832,12 +836,14 @@ xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
832
836
extern " C" xla::PjRtLoadedExecutable *
833
837
ClientCompile (PjRtClient *client, MlirModule cmod, int64_t device_id,
834
838
bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids,
835
- const char *xla_gpu_cuda_data_dir) {
836
- CompileOptions options = GenerateCompileOptions (
837
- device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir);
839
+ const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner) {
840
+ CompileOptions options =
841
+ GenerateCompileOptions (device_id, is_sharded, mesh_ids, num_mesh_ids,
842
+ xla_gpu_cuda_data_dir, use_shardy_partitioner);
838
843
839
844
mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap (cmod));
840
- if (is_sharded) {
845
+
846
+ if (is_sharded && use_shardy_partitioner) {
841
847
// https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
842
848
auto status = xla::ExportShardyForHloRoundTrip (cmod_op);
843
849
if (!status.ok ()) {
@@ -1123,6 +1129,10 @@ extern "C" void InitializePasses(MlirDialectRegistry creg) {
1123
1129
// xla + shardy specific passes
1124
1130
xla::sdy::registerSdyRoundTripExportPipeline ();
1125
1131
xla::sdy::registerSdyRoundTripImportPipeline ();
1132
+ mlir::sdy::registerAllSdyPassesAndPipelines ();
1133
+ xla::sdy::registerStablehloExportPipeline ();
1134
+ xla::sdy::registerStablehloImportPipeline ();
1135
+ xla::sdy::registerStablehloImportShardingsPass ();
1126
1136
}
1127
1137
1128
1138
extern " C" void InitializeRegistry (MlirDialectRegistry creg) {
@@ -1363,14 +1373,15 @@ ifrt_pjrt_array_create(ifrt::PjRtClient *client,
1363
1373
extern " C" xla::ifrt::LoadedExecutable *
1364
1374
ifrt_compile (ifrt::Client *client, MlirModule cmod, int64_t device_id,
1365
1375
bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids,
1366
- const char *xla_gpu_cuda_data_dir) {
1367
- xla::CompileOptions compile_options = GenerateCompileOptions (
1368
- device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir);
1376
+ const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner) {
1377
+ xla::CompileOptions compile_options =
1378
+ GenerateCompileOptions (device_id, is_sharded, mesh_ids, num_mesh_ids,
1379
+ xla_gpu_cuda_data_dir, use_shardy_partitioner);
1369
1380
auto options = std::make_unique<xla::ifrt::XlaCompileOptions>(
1370
1381
xla::ifrt::XlaCompileOptions (compile_options));
1371
1382
1372
1383
mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap (cmod));
1373
- if (is_sharded) {
1384
+ if (is_sharded && use_shardy_partitioner ) {
1374
1385
// https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
1375
1386
auto status = xla::ExportShardyForHloRoundTrip (cmod_op);
1376
1387
if (!status.ok ()) {
0 commit comments