
Implemented remaining ACS columns and prediction tasks #1

Merged · 15 commits · Jun 24, 2024
fixed ACSDataset assignment of new task
AndreFCruz committed Jun 24, 2024
commit 9ffbc7bbc0c0ea41f280bbf07f6d72f430650ea1
3 changes: 2 additions & 1 deletion folktexts/__init__.py
@@ -1,4 +1,5 @@
 from ._version import __version__, __version_info__
+from .acs import ACSDataset, ACSTaskMetadata
+from .task import TaskMetadata
 from .benchmark import BenchmarkConfig, CalibrationBenchmark
 from .classifier import LLMClassifier
-from .acs import ACSDataset, ACSTaskMetadata
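
For reference, a minimal sketch of what the reordered re-exports enable at the package top level (the task name is an assumption for illustration):

from folktexts import ACSDataset, ACSTaskMetadata, TaskMetadata

task = ACSTaskMetadata.get_task("ACSIncome")  # assumed task name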
8 changes: 4 additions & 4 deletions folktexts/acs/acs_columns.py
@@ -97,7 +97,7 @@
     value_map=partial(
         parse_pums_code,
         file=ACS_OCCP_FILE,
-        postprocess=lambda x: x[4:].lower().strip(),
+        postprocess=lambda x: x[4:].lower().capitalize().strip(),
     ),
 )

@@ -108,7 +108,7 @@
     value_map=partial(
         parse_pums_code,
         file=ACS_POBP_FILE,
-        postprocess=lambda x: x[: x.find("/")].strip(),
+        postprocess=lambda x: (x[: x.find("/")] if "/" in x else x).strip(),
     ),
 )

@@ -524,7 +524,7 @@
     "PUMA",
     short_description="Public Use Microdata Area (PUMA) code",
     use_value_map_only=True,
-    value_map=lambda x: f"PUMA code: {int(x)}",
+    value_map=lambda x: f"PUMA code: {int(x)}.",
     # missing_value_fill="N/A (less than 16 years old)",
 )

@@ -533,7 +533,7 @@
     "POWPUMA",
     short_description="place of work PUMA",
     use_value_map_only=True,
-    value_map=lambda x: f"Place of work PUMA code: {int(x)}",
+    value_map=lambda x: f"Place of work PUMA code: {int(x)}.",
     # missing_value_fill="N/A (not a worker, or worker who worked at home)",
 )
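
Two of these lambda changes are behavioral fixes, not just formatting. A small sketch of why (the input strings below are made up for illustration, not actual PUMS labels):

def old_pobp(x):
    # str.find returns -1 when "/" is absent, so this slice silently drops the last character
    return x[: x.find("/")].strip()

def new_pobp(x):
    # only truncate at "/" when one is actually present
    return (x[: x.find("/")] if "/" in x else x).strip()

print(old_pobp("Germany/Federal Republic of Germany"))  # "Germany"
print(old_pobp("Portugal"))                             # "Portuga"  <- silent truncation
print(new_pobp("Portugal"))                             # "Portugal"

def new_occp(x):
    # strips the 4-character code prefix, then sentence-cases the label
    return x[4:].lower().capitalize().strip()

print(new_occp("MGR-Chief Executives And Legislators"))  # "Chief executives and legislators"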

105 changes: 85 additions & 20 deletions folktexts/acs/acs_dataset.py
@@ -5,13 +5,16 @@
 import logging
 from pathlib import Path
 
+import pandas as pd
 from folktables import ACSDataSource
 from folktables.load_acs import state_list
 
 from ..dataset import Dataset
 from .acs_tasks import ACSTaskMetadata
 
 DEFAULT_DATA_DIR = Path("~/data").expanduser().resolve()
+DEFAULT_TEST_SIZE = 0.1
+DEFAULT_VAL_SIZE = None
 DEFAULT_SEED = 42
 
 DEFAULT_SURVEY_YEAR = "2018"
@@ -24,6 +27,27 @@ class ACSDataset(Dataset):
 
     def __init__(
         self,
+        data: pd.DataFrame,
+        full_acs_data: pd.DataFrame,
+        task: ACSTaskMetadata,
+        test_size: float = DEFAULT_TEST_SIZE,
+        val_size: float = DEFAULT_VAL_SIZE,
+        subsampling: float = None,
+        seed: int = 42,
+    ):
+        self._full_acs_data = full_acs_data
+        super().__init__(
+            data=data,
+            task=task,
+            test_size=test_size,
+            val_size=val_size,
+            subsampling=subsampling,
+            seed=seed,
+        )
+
+    @classmethod
+    def make_from_task(
+        cls,
         task: str | ACSTaskMetadata,
         cache_dir: str | Path = None,
         survey_year: str = DEFAULT_SURVEY_YEAR,
@@ -32,7 +56,7 @@ def __init__(
         seed: int = DEFAULT_SEED,
         **kwargs,
     ):
-        """Construct an ACSDataset object.
+        """Construct an ACSDataset object using ACS survey parameters.
 
         Parameters
         ----------
@@ -49,42 +73,83 @@
             The name of the survey unit to load, by default DEFAULT_SURVEY_UNIT.
         seed : int, optional
             The random seed, by default DEFAULT_SEED.
+        **kwargs
+            Extra key-word arguments to be passed to the Dataset constructor.
         """
         # Create "folktables" sub-folder under the given cache dir
         cache_dir = Path(cache_dir or DEFAULT_DATA_DIR).expanduser().resolve() / "folktables"
         if not cache_dir.exists():
             logging.warning(f"Creating cache directory '{cache_dir}' for ACS data.")
             cache_dir.mkdir(exist_ok=True, parents=False)
 
+        # Parse task if given a string
+        task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
+
         # Load ACS data source
         print("Loading ACS data...")
         data_source = ACSDataSource(
             survey_year=survey_year, horizon=horizon, survey=survey,
             root_dir=cache_dir.as_posix(),
         )
 
-        # Get ACS data in a pandas DF
-        data = data_source.get_data(
-            states=state_list, download=True, random_seed=seed,
-        )
+        # Get full ACS dataset
+        full_acs_data = data_source.get_data(
+            states=state_list, download=True, random_seed=seed)
 
+        # Parse data for this task
+        parsed_data = cls._parse_task_data(full_acs_data, task_obj)
+
+        return cls(
+            data=parsed_data,
+            full_acs_data=full_acs_data,
+            task=task,
+            seed=seed,
+            **kwargs,
+        )
 
-        # Get information on this ACS/folktables task
-        task_obj = ACSTaskMetadata.get_task(task) if isinstance(task, str) else task
+    @property
+    def task(self) -> ACSTaskMetadata:
+        return self._task
 
+    @task.setter
+    def task(self, new_task: ACSTaskMetadata):
+        # Parse data rows for new ACS task
+        self._data = self._parse_task_data(self._full_acs_data, new_task)
 
-        # Keep only rows used in this task
-        if isinstance(task_obj, ACSTaskMetadata) and task_obj.folktables_obj is not None:
-            data = task_obj.folktables_obj._preprocess(data)
+        # Check if task columns are in the data
+        if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
+            raise ValueError(
+                f"Task columns not found in dataset: "
+                f"features={new_task.features}, target={new_task.get_target()}")
 
+        self._task = new_task
+
+    @classmethod
+    def _parse_task_data(cls, full_df: pd.DataFrame, task: ACSTaskMetadata) -> pd.DataFrame:
+        """Parse a DataFrame for compatibility with the given task object.
+
+        Parameters
+        ----------
+        full_df : pd.DataFrame
+            Full DataFrame. Some rows and/or columns may be discarded for each
+            task.
+        task : ACSTaskMetadata
+            The task object used to parse the given data.
+
+        Returns
+        -------
+        parsed_df : pd.DataFrame
+            Parsed DataFrame in accordance with the given task.
+        """
+        if not isinstance(task, ACSTaskMetadata):
+            logging.error(f"Expected task of type `ACSTaskMetadata` for {type(task)}")
+            return full_df
+
+        # Parse data
+        parsed_df = task.folktables_obj._preprocess(full_df)
+
         # Threshold the target column if necessary
         # > use standardized ACS naming convention
-        if task_obj.target_threshold is not None:
-            thresholded_target = task_obj.get_target()
-            if thresholded_target not in data.columns:
-                data[thresholded_target] = task_obj.target_threshold.apply_to_column_data(data[task_obj.target])
+        if task.target_threshold is not None and task.get_target() not in parsed_df.columns:
+            parsed_df[task.get_target()] = task.target_threshold.apply_to_column_data(parsed_df[task.target])
 
-        super().__init__(
-            data=data,
-            task=task_obj,
-            seed=seed,
-            **kwargs,
-        )
+        return parsed_df
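
Taken together, a sketch of the new construction path and the task re-assignment it enables (task names and the subsampling value are assumptions for illustration):

from folktexts import ACSDataset, ACSTaskMetadata

# Downloads/caches the full ACS sample once, then parses it for the task.
task = ACSTaskMetadata.get_task("ACSIncome")  # assumed task name
dataset = ACSDataset.make_from_task(
    task=task,
    cache_dir="~/data",
    subsampling=0.01,  # forwarded to the Dataset constructor via **kwargs
)

# Because the full ACS frame is kept on the instance (_full_acs_data), the new
# task setter re-parses rows and columns without re-downloading anything.
dataset.task = ACSTaskMetadata.get_task("ACSEmployment")  # assumed task name
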
2 changes: 1 addition & 1 deletion folktexts/benchmark.py
@@ -383,7 +383,7 @@ def make_acs_benchmark(
 
     # Fetch ACS task and dataset
     acs_task = ACSTaskMetadata.get_task(task_name)
-    acs_dataset = ACSDataset(
+    acs_dataset = ACSDataset.make_from_task(
         task_obj=acs_task,
         cache_dir=data_dir,
         **acs_dataset_configs)
48 changes: 14 additions & 34 deletions folktexts/dataset.py
@@ -9,30 +9,29 @@
 """
 from __future__ import annotations
 
-import copy
 import logging
-import warnings
 from abc import ABC
 
 import numpy as np
 import pandas as pd
 
-from ._utils import hash_dict, is_valid_number, suppress_logging
+from ._utils import hash_dict, is_valid_number
 from .task import TaskMetadata
 
 DEFAULT_TEST_SIZE = 0.1
 DEFAULT_VAL_SIZE = None
+DEFAULT_SEED = 42
 
 
 class Dataset(ABC):
     def __init__(
         self,
         data: pd.DataFrame,
-        task: TaskMetadata,  # TODO: remove this from the Dataset
+        task: TaskMetadata,
         test_size: float = DEFAULT_TEST_SIZE,
         val_size: float = DEFAULT_VAL_SIZE,
         subsampling: float = None,
-        seed: int = 42,
+        seed: int = DEFAULT_SEED,
     ):
         """Construct a Dataset object.
 
@@ -92,15 +91,14 @@ def task(self) -> TaskMetadata:
         return self._task
 
     @task.setter
-    def task(self, task: TaskMetadata):
-        logging.info(f"Updating dataset's task from '{self.task.name}' to '{task.name}'.")
+    def task(self, new_task: TaskMetadata):
         # Check if task columns are in the data
-        if not all(col in self.data.columns for col in (task.features + [task.get_target()])):
+        if not all(col in self.data.columns for col in (new_task.features + [new_task.get_target()])):
             raise ValueError(
                 f"Task columns not found in dataset: "
-                f"features={task.features}, target={task.get_target()}")
+                f"features={new_task.features}, target={new_task.get_target()}")
 
-        self._task = task
+        self._task = new_task
 
     @property
     def train_size(self) -> float:
@@ -130,22 +128,6 @@ def name(self) -> str:
         hash_str = f"hash-{hash(self)}"
         return f"{self.task.name}_{subsampling_str}_{seed_str}_{hash_str}"
 
-    def __copy__(self) -> "Dataset":
-        dataset = Dataset(
-            data=self.data,
-            task=self.task,
-            test_size=self.test_size,
-            val_size=self.val_size,
-            subsampling=self.subsampling,
-            seed=self.seed,
-        )
-        dataset._train_indices = self._train_indices.copy()
-        dataset._test_indices = self._test_indices.copy()
-        dataset._val_indices = self._val_indices.copy() if self._val_indices is not None else None
-        dataset._rng = copy.deepcopy(self._rng)
-
-        return dataset
-
     def _subsample_inplace(self, subsampling: float) -> "Dataset":
         """Subsample the dataset in-place."""
 
@@ -177,11 +159,9 @@ def _subsample_inplace(self, subsampling: float) -> "Dataset":
 
         return self
 
-    def subsample(self, subsampling: float) -> "Dataset":
-        """Create a new dataset whose samples are a fraction of this dataset."""
-        with suppress_logging(logging.WARNING):
-            self_copy = copy.copy(self)
-        return self_copy._subsample_inplace(subsampling)
+    def subsample(self, subsampling: float):
+        """Subsamples this dataset in-place."""
+        return self._subsample_inplace(subsampling)
 
     def _filter_inplace(
         self,
@@ -216,9 +196,9 @@ def _filter_inplace(
 
         return self
 
-    def filter(self, population_feature_values: dict) -> "Dataset":
-        """Create a new dataset whose samples are a subset of this dataset."""
-        return copy.copy(self)._filter_inplace(population_feature_values)
+    def filter(self, population_feature_values: dict):
+        """Filter dataset rows in-place."""
+        self._filter_inplace(population_feature_values)
 
     def get_features_data(self) -> pd.DataFrame:
         return self.data[self.task.features]
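
Note the semantics change here: with __copy__ removed, subsample and filter now mutate the dataset rather than returning an independent copy. A sketch (the feature name and value are assumptions for illustration):

dataset.subsample(0.1)      # keeps ~10% of rows in place; returns the same object
dataset.filter({"SEX": 1})  # keeps only matching rows in place; now returns None
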
6 changes: 3 additions & 3 deletions folktexts/evaluation.py
@@ -115,9 +115,9 @@ def evaluate_binary_predictions_fairness(
     def group_metric_name(metric_name, group_name):
         return f"{metric_name}_group={group_name}"
 
-    assert (
-        len(unique_groups) > 1
-    ), f"Found a single unique sensitive attribute: {unique_groups}"
+    if len(unique_groups) <= 1:
+        logging.error(f"Found a single unique sensitive attribute: {unique_groups}")
+        return {}
 
     for s_value in unique_groups:
         # Indices of samples that belong to the current group
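
So a degenerate sensitive column no longer raises an AssertionError; callers now get an empty dict and should check for it. A sketch of the new calling pattern (argument names are assumptions for illustration):

metrics = evaluate_binary_predictions_fairness(y_true, y_pred, sensitive_attribute)
if not metrics:
    print("Skipping fairness metrics: only one sensitive group present.")
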
2 changes: 1 addition & 1 deletion folktexts/qa_interface.py
@@ -243,7 +243,7 @@ def create_answer_keys_permutations(cls, question: "MultipleChoiceQA") -> Iterat
         yield dataclasses.replace(question, choices=perm)
 
     @property
-    def answer_keys(self) -> list[str]:
+    def answer_keys(self) -> tuple[str]:
         return self._answer_keys_source[:len(self.choices)]
 
     @property
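
The new annotation matches runtime behavior, since slicing a tuple source yields a tuple, not a list. For illustration (the source values below are assumed):

_answer_keys_source = ("A", "B", "C", "D")
answer_keys = _answer_keys_source[:2]  # ("A", "B") -- a tuple, not a list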