Skip to content

Commit a6b729c

Browse files
apaszke authored and Google-ML-Automation committed
[Mosaic TPU] No longer implicitly trigger single-buffering when running out of VMEM
It's just more predictable to OOM and let the user explicitly opt into single buffering using `pipeline_mode=pl.Buffered(1)`. We bump the version number so that the old artifacts still get the relaxed treatment.

PiperOrigin-RevId: 820593493
1 parent 8bda006 commit a6b729c

File tree

6 files changed

+24
-17
lines changed

6 files changed

+24
-17
lines changed

jax/_src/tpu_custom_call.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -58,20 +58,20 @@
5858
# mode: for 1 month when exporting, or when using old cloud TPU.
5959
#
6060
# This can be achieved by adding:
61-
# if ctx.is_forward_compat() or is_cloud_tpu_older_than(<today>):
61+
# if ctx.is_forward_compat() or backend is None or is_cloud_tpu_older_than(<today>):
6262
# return <previous_serialization_version>
6363
# return None
6464
#
6565
# We should also add a TODO to remove the conditional one month later.
6666
def get_ir_version(ctx: mlir.LoweringRuleContext) -> int | None:
6767
backend = ctx.module_context.get_backend(optional=True)
68-
# TODO(naumsmogers): remove the forward compatibility check after 2025-09-14.
68+
# TODO(apaszke): remove the forward compatibility check after 2025-12-5.
6969
if (
7070
ctx.is_forward_compat()
7171
or backend is None
72-
or is_cloud_tpu_older_than(2025, 8, 14, backend)
72+
or is_cloud_tpu_older_than(2025, 11, 5, backend)
7373
):
74-
return 7
74+
return 8
7575
return None
7676

7777

jaxlib/mosaic/dialect/tpu/transforms/serde.cc

Lines changed: 3 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -40,7 +40,7 @@ constexpr StringRef kMangledDialect = "stable_mosaic.";
4040
constexpr StringRef kVersionAttrName = "stable_mosaic.version";
4141
// When this is bumped, we should file a TODO to update the forward-compatible
4242
// version in tpu_custom_call.py in a month!
43-
constexpr int kVersion = 8;
43+
constexpr int kVersion = 9;
4444

4545
using SerdeRuleType = jaxlib::mosaic::SerdeRuleType;
4646

@@ -311,7 +311,8 @@ void MosaicSerdePass::runOnOperation() {
311311
{.dialect_prefix = kMangledDialect,
312312
.highest_version = kVersion,
313313
.version_attr_name = kVersionAttrName,
314-
.serialize_version = serialize_version}))) {
314+
.serialize_version = serialize_version},
315+
/*keep_version_attr=*/keep_version_attr))) {
315316
signalPassFailure();
316317
}
317318
}

jaxlib/mosaic/dialect/tpu/transforms/serde.h

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -62,6 +62,8 @@ struct MosaicSerdePass : public jaxlib::mlir::Pass<MosaicSerdePass, ModuleOp> {
6262

6363
protected:
6464
::mlir::Pass::Option<bool> serialize{*this, "serialize", llvm::cl::desc("")};
65+
::mlir::Pass::Option<bool> keep_version_attr{
66+
*this, "keep-version-attr", llvm::cl::desc(""), llvm::cl::init(true)};
6567
::mlir::Pass::Option<int> target_version{*this, "target-version",
6668
llvm::cl::desc("")};
6769
};

jaxlib/mosaic/serde.cc

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,7 @@ std::optional<llvm::StringRef> demangle(llvm::StringRef name,
5555
mlir::LogicalResult RunSerde(
5656
mlir::ModuleOp module, const llvm::StringMap<SerdeRuleType>& upgrade_rules,
5757
const llvm::StringMap<SerdeRuleType>& downgrade_rules, bool serialize,
58-
SerdeOptions options) {
58+
SerdeOptions options, bool keep_version_attr) {
5959
int version = options.highest_version;
6060
int serialize_version = options.serialize_version;
6161
if (!serialize && serialize_version != -1) {
@@ -91,7 +91,9 @@ mlir::LogicalResult RunSerde(
9191
return mlir::failure();
9292
}
9393
version = version_attr.getInt();
94-
module->removeAttr(options.version_attr_name);
94+
if (!keep_version_attr) {
95+
module->removeAttr(options.version_attr_name);
96+
}
9597
}
9698
std::string storage;
9799
// Explicitly use a post-order walk to allow for deleting operations on the

jaxlib/mosaic/serde.h

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -49,9 +49,9 @@ using SerdeRuleType =
4949
// Run serialization or deserialization on the given module.
5050
::mlir::LogicalResult RunSerde(
5151
::mlir::ModuleOp module,
52-
const llvm::StringMap<SerdeRuleType> &upgrade_rules,
53-
const llvm::StringMap<SerdeRuleType> &downgrade_rules, bool serialize,
54-
SerdeOptions options);
52+
const llvm::StringMap<SerdeRuleType>& upgrade_rules,
53+
const llvm::StringMap<SerdeRuleType>& downgrade_rules, bool serialize,
54+
SerdeOptions options, bool keep_version_attr = false);
5555

5656
} // namespace jaxlib::mosaic
5757

tests/pallas/tpu_pallas_test.py

Lines changed: 8 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -2101,7 +2101,7 @@ def kernel(x_ref, y_ref):
21012101
pl.Buffered(1),
21022102
pl.Buffered(2),
21032103
])
2104-
def test_vmem_oom_error_message_basics(self, pmode):
2104+
def test_vmem_oom_error_message_basics(self, pmode: pl.Buffered):
21052105
if not jtu.if_cloud_tpu_at_least(2025, 10, 14):
21062106
self.skipTest('Support added on Oct 14, 2025')
21072107

@@ -2152,11 +2152,13 @@ def index_map(i, j):
21522152
f' full shape is f32[{shape[0]},{shape[1]}].',
21532153
error_message,
21542154
)
2155-
# When VMEM is OOM, double buffering is disabled.
2156-
self.assertIn(
2157-
'This allocation is single buffered.',
2158-
error_message,
2159-
)
2155+
if jtu.if_cloud_tpu_at_least(2025, 11, 5):
2156+
self.assertIn(
2157+
'This allocation is single buffered.'
2158+
if pmode.buffer_count == 1
2159+
else 'This allocation has 2 buffering levels',
2160+
error_message,
2161+
)
21602162

21612163
def test_vmem_oom_error_message_dynamic_grid_scalar_prefetch_and_vmem_scratch(
21622164
self,

0 commit comments

Comments
 (0)