
Commit cd0f16f

fix bug

In short: initialize tempfile_path before the compilation try block so the exception path no longer references an unbound name, and stop passing tempfile_path to graceful_eval_cleanup in the lock-file retry path; the remaining changes are import sorting and formatter-style reflowing.

1 parent: ae65c6a

File tree

1 file changed: +41 / -28 lines


src/eval.py

Lines changed: 41 additions & 28 deletions
@@ -2,19 +2,21 @@
 Helpers for Evaluations
 """
 
-import requests
-import torch
-import torch.nn as nn
+import importlib
+import json
 import os, subprocess
-from pydantic import BaseModel
-import numpy as np
 import random
-import json
-from contextlib import redirect_stdout, redirect_stderr
-from io import StringIO
 import sys
-import importlib
 import tempfile
+from contextlib import redirect_stderr, redirect_stdout
+from io import StringIO
+
+import numpy as np
+import requests
+import torch
+import torch.nn as nn
+from pydantic import BaseModel
+
 from . import utils
 
 REPO_TOP_PATH = os.path.abspath(
@@ -25,14 +27,15 @@
 )
 KERNEL_BENCH_PATH = os.path.join(REPO_TOP_PATH, "KernelBench")
 
+
 def import_ModelNew_from_code(code_string):
     """
     Writes the provided Python code string to a temporary .py file,
     dynamically imports the module so we can access 'ModelNew',
 
     This is a hack in order to allow decorators (useful for triton code) in the custom kernel code
     Unfortunately, this means that we cannot delete the tempfile until the model itself is deleted,
-    so we need to do a bit of garbage collection ourselves (callers responsibility) and delete the tempfile
+    so we need to do a bit of garbage collection ourselves (callers responsibility) and delete the tempfile
     when the model is deleted / before the program exits
     The name of the tempfile is returned so we can delete it later.
     """
@@ -179,7 +182,9 @@ def _cleanup_cuda_extensions():
         shutil.rmtree(torch_extensions_path)
 
 
-def graceful_eval_cleanup(curr_context: dict, device: torch.device, tempfile_path: str = None):
+def graceful_eval_cleanup(
+    curr_context: dict, device: torch.device, tempfile_path: str = None
+):
     """
     Clean up env, gpu cache, and compiled CUDA extensions after evaluation
     """  # delete ran-specific function definitions before next eval run
@@ -200,6 +205,7 @@ def graceful_eval_cleanup(curr_context: dict, device: torch.device, tempfile_pat
     if tempfile_path:
         os.remove(tempfile_path)
 
+
 def build_compile_cache_legacy(
     custom_model_src: str,
     verbose: bool = False,
@@ -233,11 +239,12 @@ def build_compile_cache_legacy(
         if verbose:
             print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}")
     except Exception as e:
-        print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}")
+        print(
+            f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}"
+        )
         return False, stdout_buffer.getvalue(), str(e)
-
-    return True, stdout_buffer.getvalue(), None
 
+    return True, stdout_buffer.getvalue(), None
 
 
 def build_compile_cache(
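Both build_compile_cache variants return stdout_buffer.getvalue(), and the new import block pulls in redirect_stdout and StringIO, so the capture around compilation presumably follows this pattern (a sketch under that assumption; the full function body is not part of this diff):

from contextlib import redirect_stdout
from io import StringIO

stdout_buffer = StringIO()
try:
    with redirect_stdout(stdout_buffer):
        print("[build] compiling custom CUDA kernel...")  # stand-in for the real build
except Exception as e:
    print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}")
print(stdout_buffer.getvalue())  # the captured log that gets returned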
@@ -273,16 +280,16 @@ def build_compile_cache(
         if verbose:
             print(f"[Compilation] Compilation Successful, saved cache at: {build_dir}")
     except Exception as e:
-        print(f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}")
+        print(
+            f"[Compilation] Failed to compile custom CUDA kernel. Unable to cache, \nError: {e}"
+        )
         return False, stdout_buffer.getvalue(), str(e)
 
     return True, stdout_buffer.getvalue(), None
 
 
 def build_compile_cache_with_capturing(
-    custom_model_src: str,
-    verbose: bool = False,
-    build_dir: os.PathLike = None
+    custom_model_src: str, verbose: bool = False, build_dir: os.PathLike = None
 ) -> tuple[int, str, str]:
     """
     Write a temporary python file to compile the custom model on CPU
@@ -304,22 +311,21 @@ def build_compile_cache_with_capturing(
     f.write(custom_model_src)
 
     # Execute the temporary Python file and capture output
-    process = subprocess.Popen(['python', tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+    process = subprocess.Popen(
+        ["python", tmp], stdout=subprocess.PIPE, stderr=subprocess.PIPE
+    )
     stdout, stderr = process.communicate()
     returncode = process.returncode
 
     # Clean up temporary file
     os.remove(tmp)
 
-
     if verbose:
         print("[CPU Precompile] return code: ", returncode)
-        print("[CPU Precompile] stdout: \n", stdout.decode('utf-8'))
-        print("[CPU Precompile] stderr: \n", stderr.decode('utf-8'))
-
-    return returncode, stdout.decode('utf-8'), stderr.decode('utf-8')
-
+        print("[CPU Precompile] stdout: \n", stdout.decode("utf-8"))
+        print("[CPU Precompile] stderr: \n", stderr.decode("utf-8"))
 
+    return returncode, stdout.decode("utf-8"), stderr.decode("utf-8")
 
 
 def eval_kernel_against_ref(
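The reformatted Popen call is behaviorally unchanged. For comparison, the same capture can be written with subprocess.run, which waits for the child and collects output in one call (an alternative shown for context, not what this commit does):

import subprocess

result = subprocess.run(
    ["python", "tmp_kernel.py"],  # hypothetical temp file path
    capture_output=True,
    text=True,  # decode stdout/stderr to str instead of bytes
)
print(result.returncode, result.stdout, result.stderr)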
@@ -331,7 +337,9 @@ def eval_kernel_against_ref(
     verbose: bool = False,
     measure_performance: bool = False,
     build_dir: os.PathLike = None,
-    device: torch.device = torch.cuda.current_device() if torch.cuda.is_available() else None, # have to run on GPU
+    device: torch.device = (
+        torch.cuda.current_device() if torch.cuda.is_available() else None
+    ),  # have to run on GPU
 ) -> KernelExecResult:
     """
     Evaluate the custom kernel against the original model
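The wrapped default also makes an existing subtlety easier to see: a Python default argument is evaluated once, when the def statement runs at import time, so the device here is pinned at module load rather than chosen per call. A minimal demonstration of that language behavior:

import time

def snapshot(ts=time.time()):  # default computed once, at definition time
    return ts

a = snapshot()
time.sleep(0.1)
b = snapshot()
assert a == b  # both calls see the same default value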
@@ -382,9 +390,12 @@ def eval_kernel_against_ref(
 
     # this is where compilation happens
     try:
+        tempfile_path = None  # in case load_custom_model fails
         os.environ["TORCH_USE_CUDA_DSA"] = "1"  # compile with device side assertion
         # add hash for later to distinguish between multi-turn kernels
-        ModelNew, tempfile_path = load_custom_model(custom_model_src, context, build_dir)
+        ModelNew, tempfile_path = load_custom_model(
+            custom_model_src, context, build_dir
+        )
         torch.cuda.synchronize(device=device)  # not sure if this is too much
     except Exception as e:
         print(
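The added tempfile_path = None is the actual bug fix: if load_custom_model raises, tempfile_path is never bound, so any later reference to it in the exception path would raise NameError and mask the real compilation error. A minimal reproduction of the failure mode, with load_custom_model_stub standing in for the real call:

def load_custom_model_stub():
    raise RuntimeError("compilation failed")

try:
    tempfile_path = None  # the fix: bind the name before the call can raise
    ModelNew, tempfile_path = load_custom_model_stub()
except RuntimeError as e:
    # Without the pre-binding, referencing tempfile_path here would raise
    # "NameError: name 'tempfile_path' is not defined".
    print(f"compile error: {e}; tempfile_path = {tempfile_path}")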
@@ -398,7 +409,7 @@ def eval_kernel_against_ref(
             print(
                 f"[Eval] Lock file error during compilation, Please retry. Error: {e}"
             )
-            graceful_eval_cleanup(context, device, tempfile_path)
+            graceful_eval_cleanup(context, device)
             return None
         else:
             metadata["compilation_error"] = e
@@ -709,11 +720,13 @@ def check_metadata_serializable(metadata: dict):
 
     return metadata
 
+
 def check_metadata_serializable_all_types(metadata: dict):
     """
     Ensure metadata is JSON serializable,
     if not, convert non-serializable values to strings recursively
     """
+
     def convert_to_serializable(obj):
         if isinstance(obj, dict):
             return {k: convert_to_serializable(v) for k, v in obj.items()}
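Only the dict branch of convert_to_serializable appears in this hunk. A plausible completion of the recursion, consistent with the docstring's "convert non-serializable values to strings"; the non-dict branches are assumptions:

import json

def convert_to_serializable(obj):
    if isinstance(obj, dict):
        return {k: convert_to_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_to_serializable(v) for v in obj]
    try:
        json.dumps(obj)  # already JSON serializable? keep as-is
        return obj
    except (TypeError, ValueError):
        return str(obj)  # fall back to the string representation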
