Skip to content

Commit 1ca7386

Browse files
authored
feat(JLL): changes for #926 (#927)
1 parent c003b22 commit 1ca7386

File tree

1 file changed

+21
-10
lines changed

1 file changed

+21
-10
lines changed

deps/ReactantExtra/API.cpp

Lines changed: 21 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -89,9 +89,11 @@
8989

9090
// shardy
9191
#include "shardy/dialect/sdy/ir/dialect.h"
92+
#include "shardy/dialect/sdy/transforms/passes.h"
9293
#include "shardy/integrations/c/attributes.h"
9394
#include "xla/pjrt/mlir_to_hlo.h"
9495
#include "xla/service/spmd/shardy/stablehlo_round_trip/export_shardings.h"
96+
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_export.h"
9597
#include "xla/service/spmd/shardy/stablehlo_round_trip/stablehlo_import.h"
9698

9799
// IFRT
@@ -780,7 +782,8 @@ extern "C" void FutureAwait(FutureType *Future) { Future->Await(); }
780782
xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
781783
const int64_t *mesh_ids,
782784
int64_t num_mesh_ids,
783-
const char *xla_gpu_cuda_data_dir) {
785+
const char *xla_gpu_cuda_data_dir,
786+
bool use_shardy_partitioner) {
784787
xla::CompileOptions options;
785788
options.executable_build_options.mutable_debug_options()
786789
->set_xla_gpu_cuda_data_dir(xla_gpu_cuda_data_dir);
@@ -792,7 +795,8 @@ xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
792795
options.executable_build_options.set_num_partitions(num_mesh_ids);
793796

794797
options.executable_build_options.set_use_spmd_partitioning(true);
795-
options.executable_build_options.set_use_shardy_partitioner(true);
798+
options.executable_build_options.set_use_shardy_partitioner(
799+
use_shardy_partitioner);
796800

797801
// auto partitioning for GPUs is not available in open source version of XLA
798802
// options.executable_build_options.set_use_auto_spmd_partitioning(true);
@@ -832,12 +836,14 @@ xla::CompileOptions GenerateCompileOptions(int64_t device_id, bool is_sharded,
832836
extern "C" xla::PjRtLoadedExecutable *
833837
ClientCompile(PjRtClient *client, MlirModule cmod, int64_t device_id,
834838
bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids,
835-
const char *xla_gpu_cuda_data_dir) {
836-
CompileOptions options = GenerateCompileOptions(
837-
device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir);
839+
const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner) {
840+
CompileOptions options =
841+
GenerateCompileOptions(device_id, is_sharded, mesh_ids, num_mesh_ids,
842+
xla_gpu_cuda_data_dir, use_shardy_partitioner);
838843

839844
mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap(cmod));
840-
if (is_sharded) {
845+
846+
if (is_sharded && use_shardy_partitioner) {
841847
// https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
842848
auto status = xla::ExportShardyForHloRoundTrip(cmod_op);
843849
if (!status.ok()) {
@@ -1123,6 +1129,10 @@ extern "C" void InitializePasses(MlirDialectRegistry creg) {
11231129
// xla + shardy specific passes
11241130
xla::sdy::registerSdyRoundTripExportPipeline();
11251131
xla::sdy::registerSdyRoundTripImportPipeline();
1132+
mlir::sdy::registerAllSdyPassesAndPipelines();
1133+
xla::sdy::registerStablehloExportPipeline();
1134+
xla::sdy::registerStablehloImportPipeline();
1135+
xla::sdy::registerStablehloImportShardingsPass();
11261136
}
11271137

11281138
extern "C" void InitializeRegistry(MlirDialectRegistry creg) {
@@ -1363,14 +1373,15 @@ ifrt_pjrt_array_create(ifrt::PjRtClient *client,
13631373
extern "C" xla::ifrt::LoadedExecutable *
13641374
ifrt_compile(ifrt::Client *client, MlirModule cmod, int64_t device_id,
13651375
bool is_sharded, const int64_t *mesh_ids, int64_t num_mesh_ids,
1366-
const char *xla_gpu_cuda_data_dir) {
1367-
xla::CompileOptions compile_options = GenerateCompileOptions(
1368-
device_id, is_sharded, mesh_ids, num_mesh_ids, xla_gpu_cuda_data_dir);
1376+
const char *xla_gpu_cuda_data_dir, bool use_shardy_partitioner) {
1377+
xla::CompileOptions compile_options =
1378+
GenerateCompileOptions(device_id, is_sharded, mesh_ids, num_mesh_ids,
1379+
xla_gpu_cuda_data_dir, use_shardy_partitioner);
13691380
auto options = std::make_unique<xla::ifrt::XlaCompileOptions>(
13701381
xla::ifrt::XlaCompileOptions(compile_options));
13711382

13721383
mlir::ModuleOp cmod_op = cast<ModuleOp>(*unwrap(cmod));
1373-
if (is_sharded) {
1384+
if (is_sharded && use_shardy_partitioner) {
13741385
// https://github.com/openxla/xla/blob/b3c641b05692f3712fb3c272e38665fdfa28bdf8/xla/python/py_client.cc#L460
13751386
auto status = xla::ExportShardyForHloRoundTrip(cmod_op);
13761387
if (!status.ok()) {

0 commit comments

Comments
 (0)