run_and_check Modal script (adapted from the PR from tcapelle) #39

Open · wants to merge 1 commit into base: main
scripts/run_and_check_modal.py: 323 additions, 0 deletions
@@ -0,0 +1,323 @@
import importlib.util
import os
import shutil
import sys
import tempfile

import modal
import pydra
import torch
import numpy as np

from pydra import REQUIRED, Config

from src.eval import eval_kernel_against_ref
from src.utils import set_gpu_arch, read_file


"""
Run a (reference, solution) pair to check whether the solution is correct and to compute its speedup over the reference, using Modal.
Usage:
python3 scripts/run_and_check_modal.py ref_arch_src_path=src/prompts/model_ex_add.py kernel_src_path=src/prompts/model_new_ex_add.py
"""

torch.set_printoptions(precision=4, threshold=10)
app = modal.App("run_and_check")
gpu_arch_mapping = {
"L40S": ["Ada"],
"H100": ["Hopper"],
"A100": ["Ampere"],
"L4": ["Ada"],
"T4": ["Turing"],
"A10G": ["Ampere"],
}
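# These architecture names are consumed by set_gpu_arch, which (in KernelBench)
# is expected to export them via TORCH_CUDA_ARCH_LIST so that CUDA extensions
# are built for the right compute capability.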


class ScriptConfig(Config):
def __init__(self):
# Required file paths
self.ref_arch_src_path = REQUIRED # Reference implementation
self.kernel_src_path = REQUIRED # Custom kernel implementation
self.gpu = "L40S" # GPU type for modal
self.num_correct_trials = 5 # Number of trials for correctness
self.num_perf_trials = 100 # Number of trials for performance
self.timeout = 300 # Timeout for each trial
self.verbose = False # Verbose logging
self.measure_performance = True # Whether to measure performance
self.build_dir_prefix = "" # Custom build directory prefix
self.clear_cache = False # Whether to clear build cache
self.gpu_arch = ["Ada"] # Default GPU architecture

def __repr__(self):
return f"ScriptConfig({self.to_dict()})"


# Configure Modal image
cuda_version = "12.8.0"
flavor = "devel"
operating_sys = "ubuntu22.04"
tag = f"{cuda_version}-{flavor}-{operating_sys}"

image = (
modal.Image.from_registry(f"nvidia/cuda:{tag}", add_python="3.10")
    .apt_install(
        "git",
        "gcc-10",
        "g++-10",
        "clang",  # note: a setup step from the original image recipe is skipped here
    )
.pip_install( # required to build flash-attn
"anthropic",
"numpy",
"openai",
"packaging",
"pydra_config",
"torch==2.5.0",
"tqdm",
"datasets",
"transformers",
"google-generativeai",
"together",
"pytest",
"ninja",
"utils",
)
)


@app.cls(image=image)
class EvalFunc:
@modal.method()
def evaluate_single_sample_src_modal(
self, ref_arch_src, kernel_src, configs, gpu_arch
):
"""Evaluate a single sample source code against a reference source code"""

        set_gpu_arch(gpu_arch)

        # Honor the script config rather than hard-coding trial counts
        eval_result = eval_kernel_against_ref(
            ref_arch_src,
            kernel_src,
            verbose=configs.verbose,
            measure_performance=configs.measure_performance,
            num_correct_trials=configs.num_correct_trials,
            num_perf_trials=configs.num_perf_trials,
        )

return eval_result
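
    # eval_kernel_against_ref is expected to return an object exposing at least
    # `.correctness` and `.runtime` (kernel execution time in ms); main() below
    # relies on exactly those two attributes.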

@modal.method()
def measure_program_time(
self,
ref_arch_src,
num_trials,
use_torch_compile=False,
torch_compile_backend=None,
torch_compile_options=None,
gpu_arch=None,
):
"""Measure the execution time of a reference program"""

# Setup
if gpu_arch:
set_gpu_arch(gpu_arch)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Create temporary module
temp_dir = tempfile.mkdtemp()
ref_module_path = os.path.join(temp_dir, "ref_module.py")

with open(ref_module_path, "w") as f:
f.write(ref_arch_src)

# Load reference module
spec = importlib.util.spec_from_file_location("ref_module", ref_module_path)
ref_module = importlib.util.module_from_spec(spec)
sys.modules["ref_module"] = ref_module
spec.loader.exec_module(ref_module)

# Create model instance
if hasattr(ref_module, "get_init_inputs"):
init_inputs = ref_module.get_init_inputs()
            # .to(device) is a no-op for tensors already on the target device
            init_inputs = [
                x.to(device) if isinstance(x, torch.Tensor) else x
                for x in init_inputs
            ]
ref_model = ref_module.Model(*init_inputs).to(device)
else:
ref_model = ref_module.Model().to(device)

# Apply torch.compile if needed
if use_torch_compile:
if torch_compile_backend is not None:
if (
torch_compile_options is not None
and torch_compile_options != "default"
):
compile_options = (
{"mode": torch_compile_options}
if torch_compile_options in ["max-autotune", "reduce-overhead"]
else {}
)
ref_model = torch.compile(
ref_model,
backend=torch_compile_backend,
options=compile_options,
)
else:
ref_model = torch.compile(ref_model, backend=torch_compile_backend)
else:
ref_model = torch.compile(ref_model)

# Generate inputs
if hasattr(ref_module, "get_inputs"):
inputs = ref_module.get_inputs()
            inputs = [
                x.to(device) if isinstance(x, torch.Tensor) else x
                for x in inputs
            ]
elif hasattr(ref_module, "INPUT_SHAPE"):
input_shape = ref_module.INPUT_SHAPE
if isinstance(input_shape, tuple):
inputs = (torch.randn(input_shape, device=device),)
elif isinstance(input_shape, list):
inputs = tuple(
torch.randn(shape, device=device) for shape in input_shape
)
else:
raise ValueError(f"Invalid INPUT_SHAPE: {input_shape}")
else:
# Infer inputs from model
if hasattr(ref_model, "forward"):
argcount = ref_model.forward.__code__.co_argcount
inputs = tuple(
torch.randn(1, 128, device=device) for _ in range(argcount - 1)
)
else:
raise ValueError("Could not determine appropriate inputs for the model")

# Warmup
for _ in range(10):
ref_model(*inputs)

# Timing
torch.cuda.synchronize()
times = []
for _ in range(num_trials):
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

start.record()
ref_model(*inputs)
end.record()

torch.cuda.synchronize()
times.append(start.elapsed_time(end))

# Clean up
try:
os.remove(ref_module_path)
os.rmdir(temp_dir)
except OSError:
shutil.rmtree(temp_dir, ignore_errors=True)

# Calculate statistics
times = np.array(times)
return {
"mean": float(np.mean(times)),
"std": float(np.std(times)),
"min": float(np.min(times)),
"max": float(np.max(times)),
"median": float(np.median(times)),
}
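
# measure_program_time returns timing statistics in milliseconds; an illustrative
# (made-up) example: {"mean": 1.23, "std": 0.05, "min": 1.17, "max": 1.41, "median": 1.22}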


@pydra.main(base=ScriptConfig)
def main(config: ScriptConfig):
print("Running with config", config)

# Read source files
ref_arch_src = read_file(config.ref_arch_src_path)
kernel_src = read_file(config.kernel_src_path)

# Prepare GPU architecture settings
gpu_arch = gpu_arch_mapping.get(config.gpu, config.gpu_arch)
print(f"[INFO] Using GPU architecture: {gpu_arch}")

# Start Evaluation
with app.run():
# Evaluate kernel against reference code
print("[INFO] Evaluating kernel against reference code")
kernel_eval_result = EvalFunc.with_options(
gpu=config.gpu
)().evaluate_single_sample_src_modal.remote(
ref_arch_src=ref_arch_src,
kernel_src=kernel_src,
configs=config,
gpu_arch=gpu_arch,
)
print(f"Raw result: {kernel_eval_result}, {type(kernel_eval_result)}")

kernel_exec_time = kernel_eval_result.runtime

# Measure baseline time for PyTorch Eager
print("[INFO] Measuring reference program time (eager mode)")
ref_time_eager_result = EvalFunc.with_options(
gpu=config.gpu
)().measure_program_time.remote(
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=False,
torch_compile_backend=None,
torch_compile_options=None,
gpu_arch=gpu_arch,
)
ref_exec_eager_time = ref_time_eager_result.get("mean", None)

# Measure Torch Compile time
print("[INFO] Measuring reference program time (torch.compile)")
ref_time_compile_result = EvalFunc.with_options(
gpu=config.gpu
)().measure_program_time.remote(
ref_arch_src=ref_arch_src,
num_trials=config.num_perf_trials,
use_torch_compile=True,
torch_compile_backend="inductor",
torch_compile_options="default",
gpu_arch=gpu_arch,
)
ref_exec_compile_time = ref_time_compile_result.get("mean", None)

# Print results
print("=" * 40)
print(f"[Eval] Kernel eval result: {kernel_eval_result}")
print("-" * 40)
print(f"[Timing] PyTorch Reference Eager exec time: {ref_exec_eager_time} ms")
print(f"[Timing] PyTorch Reference torch.compile time: {ref_exec_compile_time} ms")
print(f"[Timing] Custom Kernel exec time: {kernel_exec_time} ms")
print("-" * 40)

if kernel_eval_result.correctness:
print(
f"[Speedup] Speedup over eager: {ref_exec_eager_time / kernel_exec_time:.2f}x"
)
print(
f"[Speedup] Speedup over torch.compile: {ref_exec_compile_time / kernel_exec_time:.2f}x"
)
else:
print("[Speedup] Speedup Not Available as Kernel did not pass correctness")

print("=" * 40)


if __name__ == "__main__":
main()