#sdy Fallback to GSPMD in JAX export if the loaded module was lowered for GSPMD. #29033

Open · wants to merge 1 commit into main
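
As a rough illustration of the behavior this change targets, here is a minimal sketch adapted from the new test in `tests/export_test.py`. The mesh construction via `jax.sharding.Mesh`, the `(32, 32)` shapes, and the use of `jax.export.export` are illustrative assumptions, not code from this diff.

```python
# Minimal sketch: export with GSPMD, then call the exported function with
# Shardy enabled. With this change the compiler detects the GSPMD-lowered
# module and falls back to GSPMD propagation instead of failing.
import jax
import jax.numpy as jnp
import numpy as np
from jax import export
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Assumes the number of local devices divides 32 (e.g. 1, 2, 4, or 8).
mesh = Mesh(np.array(jax.devices()), ("a",))

@jax.jit
def f(x, y):
  return jax.lax.with_sharding_constraint(x + y, NamedSharding(mesh, P("a")))

# Export with the GSPMD partitioner (Shardy disabled).
jax.config.update("jax_use_shardy_partitioner", False)
args = (
    jax.ShapeDtypeStruct((32, 32), np.float32,
                         sharding=NamedSharding(mesh, P(None, "a"))),
    jax.ShapeDtypeStruct((32, 32), np.float32,
                         sharding=NamedSharding(mesh, P("a"))),
)
exp = export.export(f)(*args)

# Call the exported function with Shardy enabled: the loaded module was
# lowered for GSPMD, so compilation falls back to GSPMD.
jax.config.update("jax_use_shardy_partitioner", True)
a = jax.device_put(jnp.arange(32 * 32, dtype=np.float32).reshape(32, 32),
                   NamedSharding(mesh, P(None, "a")))
b = jax.device_put(jnp.arange(32 * 32, dtype=np.float32).reshape(32, 32),
                   NamedSharding(mesh, P("a")))
out = jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b)
```

The reverse direction, exporting with Shardy and then calling with Shardy disabled, instead raises a `ValueError` asking the caller to set `jax_use_shardy_partitioner=True`.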
30 changes: 20 additions & 10 deletions jax/_src/export/_export.py
@@ -1431,9 +1431,16 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
submodule_bc = mlir.module_to_bytecode(submodule)
shardy_enabled = _jax.sdy.lowered_with_shardy(submodule_bc)
if shardy_enabled:
if not config.use_shardy_partitioner.value:
raise ValueError(
"The function was exported with shardy enabled but you are calling "
"it with Shardy disabled. Please enable Shardy "
"- `jax_use_shardy_partitioner=True`.")
submodule = ir.Module.parse(
_jax.sdy.sdy_round_trip_import_shardings(submodule_bc)
)
elif config.use_shardy_partitioner.value:
shardy_enabled = True

with submodule.context:
pipeline = passmanager.PassManager.parse(
@@ -1444,7 +1451,7 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
if shardy_enabled:
sdy_mesh_axes = _jax.sdy.get_mesh(mlir.module_to_bytecode(submodule))
mesh = (mesh_lib.AbstractMesh(*list(zip(*sdy_mesh_axes))[::-1])
if sdy_mesh_axes else mesh_lib.empty_abstract_mesh)
if sdy_mesh_axes else None)

axis_context = ctx.module_context.axis_context
if isinstance(axis_context, sharding_impls.ShardingContext):
@@ -1473,15 +1480,15 @@ def _call_exported_lowering(ctx: mlir.LoweringRuleContext, *args,
)

# Apply in_shardings
if shardy_enabled:
if mesh:
args = tuple(
wrap_with_sharding(
ctx, x, x_aval,
_hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type]
_hlo_sharding_to_named_sharding(x_sharding, mesh), use_shardy=True) # type: ignore[arg-type]
for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo))
else:
args = tuple(
wrap_with_sharding(ctx, x, x_aval, x_sharding)
wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False)
for x, x_aval, x_sharding in zip(args, ctx.avals_in, exported.in_shardings_hlo))

symtab = ir.SymbolTable(submodule.operation)
@@ -1570,14 +1577,15 @@ def convert_shape(x: ir.Value, x_aval: core.AbstractValue, new_aval: core.Abstra
for out, out_aval, refined_out_aval in zip(call.results[len(ordered_effects):],
exported.out_avals, ctx.avals_out))
# Apply out_shardings
if shardy_enabled:
if mesh:
results = tuple(
wrap_with_sharding(
ctx, x, x_aval, _hlo_sharding_to_named_sharding(x_sharding, mesh)) # type: ignore[arg-type]
ctx, x, x_aval, _hlo_sharding_to_named_sharding(x_sharding, mesh),
use_shardy=True) # type: ignore[arg-type]
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo))
else:
results = tuple(
wrap_with_sharding(ctx, x, x_aval, x_sharding)
wrap_with_sharding(ctx, x, x_aval, x_sharding, use_shardy=False)
for x, x_aval, x_sharding in zip(results, ctx.avals_out, exported.out_shardings_hlo))
return results

@@ -1588,12 +1596,14 @@ def wrap_with_sharding(
ctx: mlir.LoweringRuleContext,
x: ir.Value,
x_aval: core.AbstractValue,
x_sharding: sharding_impls.NamedSharding | HloSharding | None,
x_sharding: sharding_impls.NamedSharding | sharding_impls.GSPMDSharding | None,
use_shardy: bool,
) -> ir.Value:
if x_sharding is None:
return x
if config.use_shardy_partitioner.value:
if use_shardy:
x_sharding = x_sharding._to_sdy_sharding(x_aval.ndim) # type: ignore
else:
x_sharding = x_sharding.to_proto() # type: ignore
return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding) # type: ignore[arg-type]
return mlir.wrap_with_sharding_op(ctx, x, x_aval, x_sharding, # type: ignore[arg-type]
allow_shardy_lowering=use_shardy)
4 changes: 2 additions & 2 deletions jax/_src/interpreters/mlir.py
@@ -2731,7 +2731,7 @@ def lower_with_sharding_in_types(ctx, op, aval, sharding_proto=None):


def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList):
if config.use_shardy_partitioner.value:
if isinstance(sharding, (SdyArray, SdyArrayList)):
op.attributes["sdy.sharding"] = get_sharding_attr(sharding)
else:
op.attributes["mhlo.sharding"] = get_sharding_attr(sharding)
@@ -2740,7 +2740,7 @@ def set_sharding(op, sharding: xc.OpSharding | SdyArray | SdyArrayList):
def get_sharding_attr(
sharding: xc.OpSharding | SdyArray | SdyArrayList
) -> ir.Attribute:
if config.use_shardy_partitioner.value:
if isinstance(sharding, (SdyArray, SdyArrayList)):
return sharding.build() # type: ignore
else:
# If there are very large numbers of devices, use the proto representation.
5 changes: 4 additions & 1 deletion jaxlib/BUILD
@@ -869,12 +869,14 @@ cc_library(
"@com_google_absl//absl/strings",
"@com_google_absl//absl/strings:cord",
"@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@nanobind",
"@shardy//shardy/dialect/sdy/ir:dialect",
"@stablehlo//:stablehlo_ops",
"@tsl//tsl/platform:fingerprint",
"@tsl//tsl/platform:ml_dtypes",
"@tsl//tsl/profiler/lib:traceme",
@@ -914,6 +916,7 @@ cc_library(
"@xla//xla/python/pjrt_ifrt:pjrt_dtype",
"@xla//xla/python/pjrt_ifrt:xla_ifrt",
"@xla//xla/service:platform_util",
"@xla//xla/service/spmd/shardy:constants",
"@xla//xla/tsl/concurrency:ref_count",
"@xla//xla/tsl/framework:allocator",
"@xla//xla/tsl/platform:env",
71 changes: 71 additions & 0 deletions jaxlib/py_client.cc
@@ -35,9 +35,11 @@ limitations under the License.
#include "absl/strings/string_view.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/OwningOpRef.h"
#include "mlir/IR/Visitors.h"
#include "mlir/Pass/PassManager.h"
#include "nanobind/nanobind.h"
#include "nanobind/stl/optional.h" // IWYU pragma: keep
@@ -48,6 +50,7 @@ limitations under the License.
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/variant.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "shardy/dialect/sdy/ir/dialect.h"
#include "jaxlib/guard_lib.h"
#include "jaxlib/nb_class_ptr.h"
#include "jaxlib/py_array.h"
@@ -60,6 +63,7 @@ limitations under the License.
#include "jaxlib/python_ref_manager.h"
#include "jaxlib/sharding.h"
#include "jaxlib/traceback.h"
#include "stablehlo/dialect/StablehloOps.h"
#include "xla/literal.h"
#include "xla/pjrt/exceptions.h"
#include "xla/pjrt/mlir_to_hlo.h"
@@ -88,6 +92,7 @@ limitations under the License.
#include "xla/python/types.h"
#include "xla/python/version.h"
#include "xla/service/platform_util.h" // IWYU pragma: keep
#include "xla/service/spmd/shardy/constants.h"
#include "xla/shape.h"
#include "xla/status_macros.h"
#include "xla/tsl/concurrency/ref_count.h"
@@ -398,6 +403,61 @@ MakeIfrtDeserializeExecutableOptions(std::optional<CompileOptions> options,
std::move(ifrt_loaded_host_callbacks));
}

// Returns true if the module has at least one GSPMD attribute or op, like an
// `mhlo.sharding` attribute or `Sharding` custom call.
bool HasGspmdAttrsOrOps(mlir::ModuleOp module) {
bool has_gspmd = false;
// Go over all functions checking the input/output shardings.
module->walk([&has_gspmd](mlir::func::FuncOp func) {
for (int64_t arg_index = 0; arg_index < func.getNumArguments();
++arg_index) {
if (func.getArgAttr(arg_index, sdy::kXlaShardingAttr)) {
has_gspmd = true;
break;
}
}
if (has_gspmd) {
return mlir::WalkResult::interrupt();
}
for (int64_t result_index = 0; result_index < func.getNumResults();
++result_index) {
if (func.getResultAttr(result_index, sdy::kXlaShardingAttr)) {
has_gspmd = true;
break;
}
}
if (has_gspmd) {
return mlir::WalkResult::interrupt();
}
return mlir::WalkResult::advance();
});
// Check the module for a `Sharding` custom call.
if (!has_gspmd) {
module->walk([&has_gspmd](mlir::stablehlo::CustomCallOp custom_call) {
if (custom_call.getCallTargetName() ==
sdy::kShardingCustomCallTargetName) {
has_gspmd = true;
return mlir::WalkResult::interrupt();
}
return mlir::WalkResult::advance();
});
}
return has_gspmd;
}

// Returns true if the module has any kind of Shardy mesh:
// - `mesh`
// - `maximal_mesh_{X}`
// - `empty_mesh`
bool HasShardyMesh(mlir::ModuleOp module) {
bool has_mesh = false;
module.walk([&](mlir::sdy::MeshOp mesh) {
has_mesh = true;
return mlir::WalkResult::interrupt();
});
return has_mesh;
}

} // namespace

/* static */ absl::StatusOr<nb_class_ptr<PyLoadedExecutable>>
@@ -459,10 +519,21 @@ PyClient::CompileAndLoad(nb_class_ptr<PyClient> client, std::string mlir_module,
mlir::MLIRContext context;
TF_ASSIGN_OR_RETURN(mlir::OwningOpRef<mlir::ModuleOp> module,
ParseMlirModuleString(mlir_module, context));
if (options.executable_build_options.use_shardy_partitioner() &&
HasGspmdAttrsOrOps(module.get())) {
LOG(WARNING)
<< "Module has GSPMD attrs or ops, but Shardy is enabled. Disabling "
"Shardy and falling back to using GSPMD propagation.";
options.executable_build_options.set_use_shardy_partitioner(false);
}
if (options.executable_build_options.use_shardy_partitioner()) {
// Since Shardy is located in the middle of the XLA pipeline, we need to
// export it before going to HLO while preserving Shardy ops and attrs.
TF_RETURN_IF_ERROR(ExportShardyForHloRoundTrip(*module));
} else if (HasShardyMesh(module.get())) {
// Shardy is not enabled, but the module has shardy ops. Likely due to
// export loading a GSPMD checkpoint. Fall back to GSPMD.
TF_RETURN_IF_ERROR(ExportShardyForGSPMD(*module));
}
return CompileAndLoadIfrtProgram(
client, std::make_unique<xla::ifrt::HloProgram>(module.get()),
45 changes: 43 additions & 2 deletions tests/export_test.py
@@ -23,7 +23,7 @@
import re
import unittest

from absl.testing import absltest
from absl.testing import absltest, parameterized
import jax
from jax import lax
from jax import numpy as jnp
@@ -2008,7 +2008,7 @@ def f(x, y):
r = jax.jit(exp.call, out_shardings=NamedSharding(old_mesh_0, P("old_b")))(a, b)
self.assertAllClose(a + b, r)

def test_lower_wth_different_meshes_axis_names(self):
def test_lower_with_different_meshes_axis_names(self):
mesh1 = jtu.create_mesh((4, 2), ("a", "b"))
mesh2 = jtu.create_mesh((4, 2), ("x", "y"))
@jax.jit
@@ -2033,6 +2033,47 @@ def f(tree):
else:
get_exported(f)(args)

@parameterized.named_parameters(
("lower_shardy_on_load_shardy_off", True, "Please enable Shardy"),
("lower_shardy_off_load_shardy_on", False, ""),
)
def test_lower_load_with_different_partitioners(self, use_shardy_on_save,
error_msg):
old_shardy = config.use_shardy_partitioner.value
try:
jax.config.update("jax_use_shardy_partitioner", use_shardy_on_save)
mesh = jtu.create_mesh((8,), ("a",))
@jax.jit
def f(x, y):
z = x + y
return jax.lax.with_sharding_constraint(
z, NamedSharding(mesh, P("a")))

args = (
jax.ShapeDtypeStruct(
(32, 32), dtype=np.float32,
sharding=NamedSharding(mesh, P(None, "a"))),
jax.ShapeDtypeStruct(
(32, 32), dtype=np.float32,
sharding=NamedSharding(mesh, P("a"))))

exp = get_exported(f)(*args)

jax.config.update("jax_use_shardy_partitioner", not use_shardy_on_save)

a = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32))
a = jax.device_put(a, NamedSharding(mesh, P(None, "a")))
b = jnp.arange(32 * 32, dtype=np.float32).reshape((32, 32))
b = jax.device_put(b, NamedSharding(mesh, P("a")))

if use_shardy_on_save:
with self.assertRaisesRegex(ValueError, error_msg):
jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b)
else:
jax.jit(exp.call, out_shardings=NamedSharding(mesh, P("a")))(a, b)
finally:
jax.config.update("jax_use_shardy_partitioner", old_shardy)



if __name__ == "__main__":