Skip to content

Commit

Permalink
fix test code
Browse files Browse the repository at this point in the history
  • Loading branch information
AyiStar committed Jul 16, 2024
1 parent 8a4ee15 commit d7a1030
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 35 deletions.
18 changes: 0 additions & 18 deletions test/test_compile.py

This file was deleted.

17 changes: 10 additions & 7 deletions test/test_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,24 +9,27 @@

@pytest.fixture(params=[0, 1, 2, 3])
def compile_lamm(request):
completed = subprocess.run(
subprocess.run(
args=LAMMCommand.clean(),
capture_output=True,
cwd=LAMM_PROJECT_DIR,
check=False,
)
return subprocess.run(
args=LAMMCommand.compile(opt_level=request.param),
capture_output=True,
timeout=60,
cwd=LAMM_PROJECT_DIR,
shell=True,
check=False,
)
# assert completed.returncode == 0, completed.stderr
return request.param

@pytest.mark.parametrize("dtype", ["f32", "q4_0", "q4_1", "q5_0"])
@pytest.mark.parametrize("dtype", ["f32", "q4_0", "q4_1", "q8_0"])
def test_matmul_correctness(compile_lamm, dtype):
assert compile_lamm.returncode == 0, compile_lamm.stderr.decode("utf-8")
completed = subprocess.run(
args=LAMMCommand.run_benchmark(ggml_type=dtype),
capture_output=True,
timeout=60,
cwd=LAMM_PROJECT_DIR,
check=False,
)
assert completed.returncode == 0, completed.stdout.decode("utf-8")
assert completed.returncode == 0, completed.stdout.decode("utf-8")
16 changes: 6 additions & 10 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,17 @@
class LAMMCommand:

@staticmethod
def compile(opt_level: int=3, debug: bool=False, clean: bool=True) -> list[str]:

clean_cmd = ['make', 'clean', '&&']

def clean() -> list[str]:
return ['make', 'clean']

@staticmethod
def compile(opt_level: int=3, debug: bool=False) -> list[str]:
assert opt_level >= 0 and opt_level <= 3, f'Optimization level only support 0-3, but got {opt_level}'
make_options = [f'LAMM_OPT_LEVEL={opt_level}']
if debug:
make_options.append('LAMM_DEBUG=1')

compile_cmd = ['make', 'benchmark'] + make_options

if clean:
return clean_cmd + compile_cmd
else:
return compile_cmd
return compile_cmd

@staticmethod
def run_benchmark(ggml_type: str, n_threads: int=1, n_iters: int=1) -> list[str]:
Expand Down

0 comments on commit d7a1030

Please sign in to comment.