running benchmark on sub-populations
AndreFCruz committed Jun 6, 2024
1 parent 7f36990 commit 179e8ab
Showing 13 changed files with 2,339 additions and 57 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/python-tests-with-tox.yml
@@ -9,7 +9,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
-python-version: ['3.8', '3.9', '3.10', '3.11', '3.12']
+python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v3
2 changes: 1 addition & 1 deletion .github/workflows/python-tests.yml
@@ -23,7 +23,7 @@ jobs:
strategy:
fail-fast: false
matrix:
-python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
+python-version: ["3.10", "3.11", "3.12"]

steps:
- uses: actions/checkout@v4
25 changes: 19 additions & 6 deletions folktexts/benchmark.py
@@ -2,25 +2,25 @@
"""
from __future__ import annotations

+import dataclasses
import logging
from functools import partial
from pathlib import Path
-import dataclasses

import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

-from ._io import save_json, load_json
+from ._io import load_json, save_json
from ._utils import hash_dict, is_valid_number
from .acs.acs_dataset import ACSDataset
from .acs.acs_questions import acs_multiple_choice_qa_map, acs_numeric_qa_map
from .acs.acs_tasks import ACSTaskMetadata
-from .task import TaskMetadata
from .classifier import LLMClassifier
from .dataset import Dataset
from .evaluation import evaluate_predictions
from .plotting import render_evaluation_plots, render_fairness_plots
from .prompting import encode_row_prompt, encode_row_prompt_chat, encode_row_prompt_few_shot
+from .task import TaskMetadata

DEFAULT_SEED = 42

@@ -36,7 +36,7 @@ class BenchmarkConfig:
batch_size: int | None = None
context_size: int | None = None
correct_order_bias: bool = True
-feature_subset: tuple[str] | None = None
+feature_subset: list[str] | None = None
population_filter: dict | None = None
seed: int = DEFAULT_SEED

@@ -53,6 +53,15 @@ def save_to_disk(self, path: str | Path):
"""Save the configuration to disk."""
save_json(dataclasses.asdict(self), path)

+def __hash__(self) -> int:
+cfg = dataclasses.asdict(self)
+cfg["feature_subset"] = tuple(cfg["feature_subset"]) if cfg["feature_subset"] else None
+cfg["population_filter_hash"] = (
+hash_dict(cfg["population_filter"])
+if cfg["population_filter"] else None
+)
+return int(hash_dict(cfg), 16)


class CalibrationBenchmark:
"""A benchmark class for measuring and evaluating LLM calibration."""
@@ -163,9 +172,10 @@ def run(self, fit_threshold: int | False = False) -> float:
s_test = self.dataset.get_sensitive_attribute_data().loc[y_test.index]

# Get LLM risk-estimate predictions for each row in the test set
+test_predictions_save_path = self._get_predictions_save_path("test")
self._y_test_scores = self.llm_clf.predict_proba(
data=X_test,
-predictions_save_path=self._get_predictions_save_path("test"),
+predictions_save_path=test_predictions_save_path,
labels=y_test, # used only to save alongside predictions in disk
)

@@ -190,6 +200,9 @@ def run(self, fit_threshold: int | False = False) -> float:
model_name=self.llm_clf.model_name,
)

+# Save predictions save path
+self._results["predictions_path"] = test_predictions_save_path.as_posix()

# Log main results
msg = (
f"\n** Test results **\n"
@@ -305,7 +318,7 @@ def make_benchmark(
task = TaskMetadata.get_task(task)

if config.feature_subset is not None and len(config.feature_subset) > 0:
-task = task.create_task_with_feature_subset(list(config.feature_subset))
+task = task.create_task_with_feature_subset(config.feature_subset)
dataset.task = task

# Check dataset is compatible with task
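A note on the new `BenchmarkConfig.__hash__`: `feature_subset` (now a plain list) is cast to a tuple and `population_filter` is reduced to its own `hash_dict` digest, so a config holding mutable fields can still serve as a stable identifier, for example to key results per sub-population. A minimal usage sketch, assuming the remaining config fields all have defaults:

    from folktexts.benchmark import BenchmarkConfig

    cfg_all = BenchmarkConfig(seed=42)
    cfg_women = BenchmarkConfig(seed=42, population_filter={"SEX": 2})  # illustrative filter

    # Same settings apart from the sub-population filter, so the hashes differ
    # and each run's artifacts can be stored under a distinct key.
    assert hash(cfg_all) != hash(cfg_women)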
9 changes: 6 additions & 3 deletions folktexts/classifier.py
@@ -103,9 +103,12 @@ def model_name(self) -> str:

@threshold.setter
def threshold(self, value: float) -> float:
-assert 0 <= value <= 1, "Threshold must be between 0 and 1."
-logging.debug(f"Setting threshold to {value}.")
-self._threshold = value
+if not 0 <= value <= 1:
+logging.error(f"Threshold must be between 0 and 1; got {value}.")
+
+# Clip threshold to valid range
+self._threshold = np.clip(value, 0, 1)
+logging.info(f"Set threshold to {self._threshold}.")

def __hash__(self) -> int:
"""Generate a unique hash for the LLMClassifier object."""
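The threshold setter above now logs an error instead of asserting, and stores the value clipped to [0, 1] with `np.clip`. A self-contained sketch of what the clipping does (values are illustrative):

    import numpy as np

    for value in (1.3, -0.2, 0.42):
        print(np.clip(value, 0, 1))  # prints 1.0, 0.0, 0.42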
4 changes: 3 additions & 1 deletion folktexts/cli/_utils.py
@@ -29,7 +29,9 @@ def _handle_str_value(val: str) -> int | float | str | bool:
for arg in cmdline_args:
parsed_arg = arg.lstrip("-").replace("-", "_")
if "=" in parsed_arg:
-key, val = parsed_arg.split("=")
+split_idx = parsed_arg.index("=")
+key = parsed_arg[:split_idx]
+val = parsed_arg[split_idx + 1:]
kwargs_dict[key] = _handle_str_value(val)
else:
kwargs_dict[parsed_arg] = True
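Splitting only at the first "=" matters once an argument's value can itself contain "=", e.g. when a KEY=VALUE population filter is passed as the value of another --flag=... option. A short sketch of the difference, using a hypothetical token:

    arg = "--population-filter=SEX=2"  # hypothetical command-line token
    parsed_arg = arg.lstrip("-").replace("-", "_")

    # Old: `key, val = parsed_arg.split("=")` raises ValueError (3 parts, 2 targets).
    # New: split only at the first "=" so the value keeps its own "=".
    split_idx = parsed_arg.index("=")
    key = parsed_arg[:split_idx]      # "population_filter"
    val = parsed_arg[split_idx + 1:]  # "SEX=2"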
4 changes: 2 additions & 2 deletions folktexts/cli/experiments.py
@@ -17,7 +17,7 @@
DEFAULT_JOB_MEMORY_GB = 62 # GBs of memory
DEFAULT_GPU_MEMORY_GB = 30 # GBs of GPU memory

-MAX_RUNNING_PRICE = 1000
+MAX_RUNNING_PRICE = 1500 # Max price for running a job


@dataclass
@@ -97,7 +97,7 @@ def launch_experiment_job(exp: Experiment):

# Concurrency limits:
# > each job uses this amount of resources out of a pool of 10k
-"concurrency_limits": "user.llm_clf:500", # 20 jobs in parallel
+"concurrency_limits": "user.folktexts:100", # 100 jobs in parallel

"+MaxRunningPrice": MAX_RUNNING_PRICE,
"+RunningPriceExceededAction": classad.quote("restart"),
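As the concurrency_limits comment notes, each job draws its declared amount from a shared pool of 10k tokens, so the per-job charge sets the parallelism ceiling: 10000 / 500 = 20 concurrent jobs under the old user.llm_clf:500 limit, and 10000 / 100 = 100 under the new user.folktexts:100 one.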
36 changes: 16 additions & 20 deletions folktexts/cli/launch_acs_benchmarks.py
@@ -8,15 +8,14 @@
from pprint import pprint

from folktexts._io import load_json, save_json
-from folktexts._utils import get_current_date
from folktexts.llm_utils import get_model_folder_path, get_model_size_B

from .experiments import Experiment, launch_experiment_job

# All ACS prediction tasks
ACS_TASKS = (
"ACSIncome",
-# "ACSEmployment", # TODO: get other ACS tasks running
+# "ACSEmployment", # TODO: run on other ACS tasks
# "ACSMobility",
# "ACSTravelTime",
# "ACSPublicCoverage",
@@ -26,14 +25,13 @@
# Useful paths #
################
ROOT_DIR = Path("/fast/groups/sf")
-# ROOT_DIR = Path("/fast/acruz")
# ROOT_DIR = Path("~").expanduser().resolve() # on local machine

# ACS data directory
ACS_DATA_DIR = ROOT_DIR / "data"

# Directory to save results in (make sure it exists)
-RESULTS_DIR = ROOT_DIR / "folktexts-results" / get_current_date()
+RESULTS_DIR = ROOT_DIR / "folktexts-results"
RESULTS_DIR.mkdir(exist_ok=True, parents=False)

# Models save directory
@@ -49,14 +47,12 @@
BATCH_SIZE = 30
CONTEXT_SIZE = 500
CORRECT_ORDER_BIAS = True
-FIT_THRESHOLD = 100

VERBOSE = True

JOB_CPUS = 4
JOB_MEMORY_GB = 60
-# JOB_BID = 50
-JOB_BID = 505
+JOB_BID = 250

# LLMs to evaluate
LLM_MODELS = [
@@ -65,20 +61,20 @@
"google/gemma-1.1-2b-it",

# # ** Medium models **
-# "google/gemma-7b",
-# "google/gemma-1.1-7b-it",
-# "mistralai/Mistral-7B-v0.1",
-# "mistralai/Mistral-7B-Instruct-v0.2",
-# "meta-llama/Meta-Llama-3-8B",
-# "meta-llama/Meta-Llama-3-8B-Instruct",
+"google/gemma-7b",
+"google/gemma-1.1-7b-it",
+"mistralai/Mistral-7B-v0.1",
+"mistralai/Mistral-7B-Instruct-v0.2",
+"meta-llama/Meta-Llama-3-8B",
+"meta-llama/Meta-Llama-3-8B-Instruct",

# # ** Large models **
-# "01-ai/Yi-34B",
-# "01-ai/Yi-34B-Chat",
-# "mistralai/Mixtral-8x7B-v0.1",
-# "mistralai/Mixtral-8x7B-Instruct-v0.1",
-# "meta-llama/Meta-Llama-3-70B",
-# "meta-llama/Meta-Llama-3-70B-Instruct",
+"01-ai/Yi-34B",
+"01-ai/Yi-34B-Chat",
+"mistralai/Mixtral-8x7B-v0.1",
+"mistralai/Mixtral-8x7B-Instruct-v0.1",
+"meta-llama/Meta-Llama-3-70B",
+"meta-llama/Meta-Llama-3-70B-Instruct",
# "mistralai/Mixtral-8x22B-v0.1",
# "mistralai/Mixtral-8x22B-Instruct-v0.1",
]
@@ -115,7 +111,7 @@ def make_llm_as_clf_experiment(
experiment_kwargs.setdefault("batch_size", math.ceil(BATCH_SIZE / n_shots))
experiment_kwargs.setdefault("context_size", CONTEXT_SIZE * n_shots)
experiment_kwargs.setdefault("data_dir", ACS_DATA_DIR.as_posix())
-experiment_kwargs.setdefault("fit_threshold", FIT_THRESHOLD)
+# experiment_kwargs.setdefault("fit_threshold", FIT_THRESHOLD)

# Define experiment
exp = Experiment(
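For reference, the batch_size and context_size defaults above trade off against the number of few-shot examples: more shots mean smaller batches and a larger context. A quick worked example with the constants defined earlier (n_shots chosen for illustration):

    import math

    BATCH_SIZE, CONTEXT_SIZE = 30, 500
    n_shots = 5  # illustrative few-shot setting
    print(math.ceil(BATCH_SIZE / n_shots))  # 6 rows per batch
    print(CONTEXT_SIZE * n_shots)           # 2500, the enlarged context size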
11 changes: 4 additions & 7 deletions folktexts/cli/run_acs_benchmark.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python3
"""Runs the LLM calibration benchmark from the command line.
"""
-import logging
import json
+import logging
import sys
from argparse import ArgumentParser
from pathlib import Path
@@ -119,11 +119,8 @@ def setup_arg_parser() -> ArgumentParser:
# Parse population filter if provided
population_filter_dict = None
if args.use_population_filter:
-import ipdb; ipdb.set_trace() # TODO: debug
-population_filter_dict = dict()
-for filter_str in args.use_population_filter: # TODO: split by whitespace?
-col_name, col_value = filter_str.split("=")
-population_filter_dict[col_name] = col_value
+from ._utils import cmd_line_args_to_kwargs
+population_filter_dict = cmd_line_args_to_kwargs(args.use_population_filter)

# Load model and tokenizer
from folktexts.llm_utils import load_model_tokenizer
@@ -139,7 +136,7 @@ def setup_arg_parser() -> ArgumentParser:
batch_size=args.batch_size,
context_size=args.context_size,
correct_order_bias=not args.dont_correct_order_bias,
-feature_subset=tuple(args.use_feature_subset) or None,
+feature_subset=args.use_feature_subset or None,
population_filter=population_filter_dict,
seed=args.seed,
)
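The hand-rolled filter parsing (and the leftover ipdb breakpoint) is replaced by the shared `cmd_line_args_to_kwargs` helper, which should also coerce values to int/float/bool where possible, judging by `_handle_str_value`'s signature. A rough sketch of the expected behaviour, with illustrative filter tokens:

    from folktexts.cli._utils import cmd_line_args_to_kwargs

    # e.g. `--use-population-filter SEX=2 ST=6` arriving as a list of strings
    population_filter_dict = cmd_line_args_to_kwargs(["SEX=2", "ST=6"])
    print(population_filter_dict)  # expected: {"SEX": 2, "ST": 6}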
16 changes: 7 additions & 9 deletions folktexts/dataset.py
@@ -59,7 +59,6 @@ def __init__(
self._test_size = test_size
self._val_size = val_size or 0
self._train_size = 1 - self._test_size - self._val_size
-self._subsampling = subsampling
assert self._train_size > 0

self._seed = seed
@@ -79,8 +78,9 @@ def __init__(
self._val_indices = None

# Subsample the train/test/val data (if requested)
-if self._subsampling is not None:
-self._subsample_inplace(self._subsampling)
+self._subsampling = None
+if subsampling is not None:
+self._subsample_inplace(subsampling)

@property
def data(self) -> pd.DataFrame:
@@ -115,7 +115,7 @@ def val_size(self) -> float:

@property
def subsampling(self) -> float:
-return self._subsampling
+return getattr(self, "_subsampling", None)

@property
def seed(self) -> int:
@@ -155,7 +155,7 @@ def _subsample_inplace(self, subsampling: float) -> "Dataset":
self._val_indices = self._val_indices[: new_val_size]

# Update subsampling factor
-self._subsampling = (self._subsampling or 1) * subsampling
+self._subsampling = (getattr(self, "_subsampling", None) or 1) * subsampling

# Log new dataset size
msg = (
@@ -177,8 +177,6 @@ def _filter_inplace(
population_feature_values: dict,
) -> "Dataset":
"""Subset the dataset in-place: keep only samples with the given feature values."""
-import ipdb; ipdb.set_trace()

# Check argument is of valid type
if not isinstance(population_feature_values, dict):
raise ValueError(
@@ -188,8 +186,8 @@
# Check argument keys are valid columns
if not all(key in self.data.columns for key in population_feature_values.keys()):
raise ValueError(
-f"Invalid `population_feature_values` keys: "
-f"{population_feature_values.keys()}.")
+f"Invalid `population_feature_values` keys; columns don't exist "
+f"in the dataset: {list(population_feature_values.keys())}.")

# Create boolean filter based on the given feature values
population_filter = pd.Series(True, index=self.data.index)
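For context, `_filter_inplace` validates that the filter keys are real columns and then narrows a boolean mask that starts out all-True. A toy sketch of that logic; the combination step is an assumption here, since the loop body is cut off above:

    import pandas as pd

    # Toy stand-in for `self.data`; column names are illustrative.
    data = pd.DataFrame({"SEX": [1, 2, 2, 1], "AGEP": [25, 40, 31, 58]})
    population_feature_values = {"SEX": 2}

    population_filter = pd.Series(True, index=data.index)
    for key, value in population_feature_values.items():
        population_filter &= (data[key] == value)  # assumed combination step

    print(data[population_filter])  # keeps only the SEX == 2 rows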
2 changes: 1 addition & 1 deletion folktexts/evaluation.py
@@ -268,7 +268,7 @@ def evaluate_predictions(
results.update(evaluate_binary_predictions(y_true, y_pred_binary))

# Add loss functions as proxies for calibration
-results["log_loss"] = log_loss(y_true, y_pred_scores)
+results["log_loss"] = log_loss(y_true, y_pred_scores, labels=[0, 1])
results["brier_score_loss"] = brier_score_loss(y_true, y_pred_scores)

# Evaluate fairness metrics
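Passing labels=[0, 1] keeps log_loss well-defined on sub-populations whose test labels happen to contain a single class, where scikit-learn cannot infer the label set from y_true alone and raises. A small self-contained sketch:

    import numpy as np
    from sklearn.metrics import log_loss

    y_true = np.array([0, 0, 0])              # a slice where every label is 0
    y_pred_scores = np.array([0.1, 0.3, 0.2])

    # log_loss(y_true, y_pred_scores) would raise a ValueError here;
    # spelling out both labels keeps the metric computable.
    print(log_loss(y_true, y_pred_scores, labels=[0, 1]))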
2 changes: 1 addition & 1 deletion folktexts/task.py
@@ -2,8 +2,8 @@
"""
from __future__ import annotations

-import logging
import dataclasses
+import logging
from dataclasses import dataclass, field
from typing import Callable, ClassVar, Iterable

(Diffs for the remaining changed files were not loaded on this page.)
