Add lm-eval comparison script (#99)
Add a basic lm-eval harness test script that compares vLLM against HF, and an action in the nightly build that tests a Marlin model in vLLM against a GPTQ model in HF Transformers. The check is essentially `python .github/scripts/lm_eval_compare_hf_vs_vllm.py --hf_pretrained nm-testing/zephyr-beta-7b-gptq-g128 --vllm_pretrained nm-testing/zephyr-beta-7b-marlin-g128`.
mgoin authored Mar 12, 2024
1 parent 7bbd2cc commit e2d9050
Showing 4 changed files with 261 additions and 0 deletions.
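
The pass/fail decision in the new script below rests on a two-tailed z-test over each task's accuracy and standard error. A minimal sketch of that rule, with hypothetical accuracy numbers chosen only for illustration:

```python
import numpy as np
import scipy.stats

# Hypothetical per-task accuracies and standard errors (not real results).
hf_acc, hf_se = 0.80, 0.01      # HF Transformers baseline
vllm_acc, vllm_se = 0.79, 0.01  # vLLM candidate

z = (hf_acc - vllm_acc) / np.sqrt(hf_se**2 + vllm_se**2)
p_value = 2 * scipy.stats.norm.sf(abs(z))  # two-tailed p-value
# z ~= 0.71, p ~= 0.48 > alpha = 0.05, so this task would PASS.
print(f"z={z:.2f}, p={p_value:.2f}, {'PASS' if p_value > 0.05 else 'FAIL'}")
```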
26 changes: 26 additions & 0 deletions .github/actions/nm-lm-eval-accuracy/action.yml
@@ -0,0 +1,26 @@
name: run lm-eval accuracy test
description: 'run lm-eval accuracy test'
inputs:
  python:
    description: 'python version, e.g. 3.10.12'
    required: true
  venv:
    description: 'name for python virtual environment'
    required: true
runs:
  using: composite
  steps:
    - id: lm-eval
      run: |
        COMMIT=${{ github.sha }}
        VENV="${{ inputs.venv }}-${COMMIT:0:7}"
        source $(pyenv root)/versions/${{ inputs.python }}/envs/${VENV}/bin/activate
        pip3 install git+https://github.com/EleutherAI/lm-evaluation-harness.git
        pip3 install optimum auto-gptq
        SUCCESS=0
        python .github/scripts/lm_eval_compare_hf_vs_vllm.py --hf_pretrained nm-testing/zephyr-beta-7b-gptq-g128 --vllm_pretrained nm-testing/zephyr-beta-7b-marlin-g128 || SUCCESS=$?
        echo "test=${SUCCESS}" >> "$GITHUB_OUTPUT"
        exit ${SUCCESS}
      shell: bash
126 changes: 126 additions & 0 deletions .github/scripts/lm_eval_compare_hf_vs_vllm.py
@@ -0,0 +1,126 @@
import argparse
import os
from typing import Dict, List, Tuple

import numpy as np
import scipy.stats

import lm_eval
import lm_eval.models.utils

os.environ["TOKENIZERS_PARALLELISM"] = "false"


def calculate_z_value(res1: Dict, res2: Dict) -> Tuple[float, float]:
    acc1, acc2 = res1["acc,none"], res2["acc,none"]
    st_err1, st_err2 = res1["acc_stderr,none"], res2["acc_stderr,none"]
    Z = (acc1 - acc2) / np.sqrt((st_err1**2) + (st_err2**2))
    # Determining the p-value
    p_value = 2 * scipy.stats.norm.sf(abs(Z))  # two-tailed test
    return Z, p_value


def print_results(data_to_print: List = None,
                  results_dict: Dict = None,
                  alpha: float = None):
    model1_data, model2_data = data_to_print
    for task in model1_data:
        print(f"Task: {task}")
        print(f"HF Accuracy: {model1_data[task]['acc,none']}")
        print(f"vLLM Accuracy: {model2_data[task]['acc,none']}")
        print(f"HF StdErr: {model1_data[task]['acc_stderr,none']}")
        print(f"vLLM StdErr: {model2_data[task]['acc_stderr,none']}")
        z = results_dict[task]["z"]
        p_value = results_dict[task]["p_value"]
        result = "PASS" if p_value > alpha else "FAIL"
        print(f"Z-Score: {z}, P-Value: {p_value}, p > {alpha}: {result}\n")


def check_passing_score(results_dict: Dict = None,
                        alpha: float = None) -> bool:
    # Fail the comparison if any task shows a statistically significant gap.
    for task in results_dict:
        p_value = results_dict[task]["p_value"]
        if p_value <= alpha:
            return False
    return True


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--hf_pretrained",
                        default="EleutherAI/pythia-70m",
                        help="name of model to compare as baseline")
    parser.add_argument("--vllm_pretrained",
                        default="EleutherAI/pythia-70m",
                        help="name of model to compare as difference")
    parser.add_argument("--hf_args",
                        help="huggingface model args <arg>=<value>",
                        default="")
    parser.add_argument("--vllm_args",
                        help="vllm model args <arg>=<value>",
                        default="")
    parser.add_argument("--tasks", type=str, default="arc_easy,hellaswag")
    parser.add_argument(
        "--limit",
        type=float,
        default=100,
    )
    parser.add_argument(
        "--alpha",
        type=float,
        default=0.05,
        help="Significance level for two-tailed z-test",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
    )
    parser.add_argument(
        "--batch",
        type=str,
        default=4,
    )
    parser.add_argument(
        "--verbosity",
        type=str,
        default="INFO",
        help="Logging verbosity",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_args()
    tasks = args.tasks.split(",")
    print("Tasks:", tasks)
    # Avoid a dangling comma in model_args when no extra args are given.
    hf_args = "," + args.hf_args if args.hf_args else ""
    vllm_args = "," + args.vllm_args if args.vllm_args else ""
    results_hf = lm_eval.simple_evaluate(
        model="hf",
        model_args=f"pretrained={args.hf_pretrained}" + hf_args,
        tasks=tasks,
        limit=args.limit,
        device=args.device,
        batch_size=args.batch,
    )
    lm_eval.models.utils.clear_torch_cache()
    print("Memory stats cleared")
    results_vllm = lm_eval.simple_evaluate(
        model="vllm",
        model_args=f"pretrained={args.vllm_pretrained}" + vllm_args,
        tasks=tasks,
        limit=args.limit,
        device=args.device,
        batch_size=args.batch,
    )
    all_res = {}
    for task1, task2 in zip(results_hf["results"].items(),
                            results_vllm["results"].items()):
        assert task1[0] == task2[0]
        z, p_value = calculate_z_value(task1[1], task2[1])
        all_res[task1[0]] = {"z": z, "p_value": p_value}
    print_results([results_hf["results"], results_vllm["results"]], all_res,
                  args.alpha)
    if not check_passing_score(all_res, args.alpha):
        print("Accuracy test failed!")
        exit(1)
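
As a quick illustration of the aggregation rule in `check_passing_score`: a single task with p <= alpha fails the whole comparison. The values below are invented for the example and assume the function from the script above is in scope:

```python
# Hypothetical per-task z/p values (not real benchmark output).
example_results = {
    "arc_easy": {"z": 0.71, "p_value": 0.48},   # no significant accuracy gap
    "hellaswag": {"z": 2.05, "p_value": 0.04},  # significant gap at alpha=0.05
}
assert check_passing_score(example_results, alpha=0.05) is False
assert check_passing_score({"arc_easy": {"z": 0.71, "p_value": 0.48}},
                           alpha=0.05) is True
```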
11 changes: 11 additions & 0 deletions .github/workflows/nightly.yml
@@ -32,3 +32,14 @@ jobs:
      Gi_per_thread: 12
      python: "3.10.12"
    secrets: inherit

  # single gpu
  AWS-AVX2-32G-A10G-24G-Accuracy:
    uses: ./.github/workflows/nm-lm-eval-accuracy.yml
    with:
      label: aws-avx2-32G-a10g-24G
      timeout: 60
      gitref: '${{ github.ref }}'
      Gi_per_thread: 12
      python: "3.10.12"
    secrets: inherit
98 changes: 98 additions & 0 deletions .github/workflows/nm-lm-eval-accuracy.yml
@@ -0,0 +1,98 @@
name: nm-lm-eval-accuracy
on:
  # makes workflow reusable
  workflow_call:
    inputs:
      label:
        description: "requested runner label (specifies instance)"
        type: string
        required: true
      timeout:
        description: "maximum time runner will be up"
        type: string
        required: true
      gitref:
        description: "git commit hash or branch name"
        type: string
        required: true
      Gi_per_thread:
        description: 'requested GiB to reserve per thread'
        type: string
        required: true
      python:
        description: "python version, e.g. 3.10.12"
        type: string
        required: true

  # makes workflow manually callable
  workflow_dispatch:
    inputs:
      label:
        description: "requested runner label (specifies instance)"
        type: string
        required: true
      timeout:
        description: "maximum time runner will be up"
        type: string
        required: true
      gitref:
        description: "git commit hash or branch name"
        type: string
        required: true
      Gi_per_thread:
        description: 'requested GiB to reserve per thread'
        type: string
        required: true
      python:
        description: "python version, e.g. 3.10.12"
        type: string
        required: true

jobs:
  LM-EVAL:

    runs-on: ${{ inputs.label }}
    timeout-minutes: ${{ fromJSON(inputs.timeout) }}

    steps:
      - name: checkout repository code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0
          ref: ${{ inputs.gitref }}
          submodules: recursive

      - name: setenv
        id: setenv
        uses: ./.github/actions/nm-set-env/
        with:
          hf_token: ${{ secrets.NM_HF_TOKEN }}
          Gi_per_thread: ${{ inputs.Gi_per_thread }}

      - name: set python
        id: set_python
        uses: ./.github/actions/nm-set-python/
        with:
          python: ${{ inputs.python }}
          venv: TEST

      - name: hf cache
        id: hf_cache
        uses: ./.github/actions/nm-hf-cache/
        with:
          fs_cache: ${{ secrets.HF_FS_CACHE }}

      - name: build
        id: build
        uses: ./.github/actions/nm-build-vllm/
        with:
          Gi_per_thread: ${{ inputs.Gi_per_thread }}
          python: ${{ inputs.python }}
          venv: TEST

      - name: run lm-eval-accuracy
        uses: ./.github/actions/nm-lm-eval-accuracy/
        with:
          python: ${{ inputs.python }}
          venv: TEST
