diff --git a/shark/torch_mlir_utils.py b/shark/torch_mlir_utils.py index b7dfc72b74..433b4efaa0 100644 --- a/shark/torch_mlir_utils.py +++ b/shark/torch_mlir_utils.py @@ -15,6 +15,8 @@ from torch_mlir.ir import StringAttr import torch_mlir from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend +import tempfile +from shark.parser import shark_args def get_module_name_for_asm_dump(module): @@ -62,6 +64,8 @@ def get_torch_mlir_module( if jit_trace: ignore_traced_shapes = True + tempfile.tempdir = shark_args.repro_dir + module = torch_mlir.compile( module, input,