Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,6 @@ cython_debug/

checkpoints/
wandb/
statistics/
results/
images/
14 changes: 8 additions & 6 deletions sparsify/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from .data import MemmapDataset, chunk_and_tokenize
from .trainer import TrainConfig, Trainer
from .utils import simple_parse_args_string


@dataclass
Expand Down Expand Up @@ -69,6 +70,11 @@ class RunConfig(TrainConfig):
)
"""Number of processes to use for preprocessing data"""

data_args: str = field(
default="",
)
"""Arguments to pass to the HuggingFace dataset constructor."""


def load_artifacts(
args: RunConfig, rank: int
Expand Down Expand Up @@ -101,12 +107,8 @@ def load_artifacts(
else:
# For Huggingface datasets
try:
dataset = load_dataset(
args.dataset,
split=args.split,
# TODO: Maybe set this to False by default? But RPJ requires it.
trust_remote_code=True,
)
kwargs = simple_parse_args_string(args.data_args)
dataset = load_dataset(args.dataset, split=args.split, **kwargs)
except ValueError as e:
# Automatically use load_from_disk if appropriate
if "load_from_disk" in str(e):
Expand Down
30 changes: 30 additions & 0 deletions sparsify/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,3 +101,33 @@ def triton_decode(top_indices: Tensor, top_acts: Tensor, W_dec: Tensor):
decoder_impl = eager_decode
else:
decoder_impl = triton_decode


def handle_arg_string(arg):
    """Convert a raw string argument into the most specific Python value.

    The conversion is attempted in this order:

    1. ``"true"`` / ``"false"`` (case-insensitive) become ``True`` / ``False``.
    2. Anything accepted by ``int()`` becomes an ``int`` (including negative
       numbers such as ``"-5"``).
    3. Anything accepted by ``float()`` becomes a ``float``.
    4. Otherwise the string is returned unchanged.

    Args:
        arg: The raw string value to convert.

    Returns:
        A ``bool``, ``int``, ``float``, or the original ``str``.
    """
    lowered = arg.lower()
    if lowered == "true":
        return True
    if lowered == "false":
        return False

    # Attempt the conversions directly instead of pre-checking with
    # ``str.isnumeric()``: that predicate accepts characters such as "²" or
    # "½" which ``int()`` rejects (raising an uncaught ValueError), and it
    # rejects signed integers, which then ended up as floats.
    for cast in (int, float):
        try:
            return cast(arg)
        except ValueError:
            continue
    return arg


def simple_parse_args_string(args_string: str) -> dict:
    """Parse a comma-separated ``key=value`` string into a dictionary.

    For example, ``"arg1=val1,arg2=val2"`` is parsed into
    ``{"arg1": val1, "arg2": val2}``, where each value is converted by
    ``handle_arg_string``.

    Whitespace around items, keys, and values is ignored, and empty items
    (e.g. from a trailing comma) are skipped. Only the first ``=`` in an item
    separates key from value, so values may themselves contain ``=``. An item
    without ``=`` maps its key to the empty string. If a key is repeated, the
    last occurrence wins.

    Args:
        args_string: The raw argument string; may be empty.

    Returns:
        A dictionary mapping argument names to their converted values.
    """
    args_dict = {}
    for item in args_string.split(","):
        item = item.strip()
        if not item:
            continue
        # partition splits on the first "=" only, keeping any later "=" in
        # the value.
        key, _, value = item.partition("=")
        # Strip so that "a=1, b=true" yields the key "b" (not " b") and the
        # value True (not the string " true").
        args_dict[key.strip()] = handle_arg_string(value.strip())
    return args_dict