[0.5.0-UT] pallas test fused attention fwd #430

Open · wants to merge 1 commit into base: rocm-jaxlib-v0.5.0

19 changes: 15 additions & 4 deletions tests/pallas/gpu_ops_test.py
@@ -40,11 +40,21 @@
softmax = None
import jax.numpy as jnp
import numpy as np

from pathlib import Path

# TODO(sharadmv): Update signatures of pallas_call to correct inputs/outputs.
# pylint: disable=no-value-for-parameter

def get_rocm_version():
rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
version_path = Path(rocm_path) / ".info" / "version"
if not version_path.exists():
raise FileNotFoundError(f"Expected ROCm version file at {version_path}")
version_str = version_path.read_text().strip()
major, minor, *_ = version_str.split(".")
return int(major), int(minor)


config.parse_flags_with_absl()


@@ -149,11 +159,12 @@ def setUp(self):
self.skipTest("Not intended for TPU")

# Sequence length is reduced for ROCm due to large dimension not
# fitting in shared memory. Higher dimension causes "XlaRuntimeError:
# RESOURCE_EXHAUSTED: Shared memory size limit exceeded" error.
# fitting in shared memory in ROCm versions below 6.5.0. Higher
# dimension causes "XlaRuntimeError: RESOURCE_EXHAUSTED: Shared
# memory size limit exceeded" error.
@jtu.sample_product(
batch_size=(1, 2),
seq_len=(32, 64) if jtu.is_device_rocm else (128, 384),
seq_len=(32, 64) if jtu.is_device_rocm and get_rocm_version() < (6, 5) else (128, 384),
num_heads=(1, 2, 8),
head_dim=(32, 64, 128),
block_q=(64, 128),
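
For context, the snippet below is a minimal standalone sketch, not part of the PR, of the version-gating pattern the diff introduces: it restates `get_rocm_version()` as shown in the diff and applies the same `< (6, 5)` threshold to pick the reduced sequence lengths. The `ROCM_PATH` default of `/opt/rocm` and the `.info/version` file location come from the diff itself; on a machine without a ROCm install it raises the same `FileNotFoundError`.

# Standalone sketch (assumes the standard ROCm layout at $ROCM_PATH or /opt/rocm).
import os
from pathlib import Path


def get_rocm_version():
  """Return (major, minor) parsed from the ROCm install's version file."""
  rocm_path = os.environ.get("ROCM_PATH", "/opt/rocm")
  version_path = Path(rocm_path) / ".info" / "version"
  if not version_path.exists():
    raise FileNotFoundError(f"Expected ROCm version file at {version_path}")
  major, minor, *_ = version_path.read_text().strip().split(".")
  return int(major), int(minor)


if __name__ == "__main__":
  # Mirror the gating in the diff: smaller sequence lengths on ROCm releases
  # older than 6.5, where the larger shapes exhaust GPU shared memory.
  seq_len = (32, 64) if get_rocm_version() < (6, 5) else (128, 384)
  print(f"seq_len candidates for this ROCm install: {seq_len}")

Comparing `(major, minor)` as an integer tuple rather than as a version string keeps the check correct once ROCm reaches 6.10, where lexicographic string comparison would order "6.10" before "6.5".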