Commit 813c8e7

Merge branch 'tensor_compression' of github.com:neuralmagic/sparseml into tensor_compression

Sara Adkins committed Mar 15, 2024
2 parents 013d17b + 1749b28 commit 813c8e7
Showing 12 changed files with 181 additions and 22 deletions.
58 changes: 58 additions & 0 deletions .github/workflows/build-wheel.yml
@@ -0,0 +1,58 @@
name: Build PyPi Wheel
on:
pull_request:
types: [opened, synchronize, reopened]
branches:
- main
- 'release/[0-9]+.[0-9]+'
push:
branches:
- main
release:
types: [created, published]
schedule:
- cron: '0 0 * * *'

permissions:
id-token: write
contents: read

concurrency:
group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true

# if this is not a dev or release build, a nightly build is created
# everything is pushed to the internal bucket unless the build comes from the nightly scheduled cron job or a release
# TODO: the release tag workflow is missing and still needs to be added
env:
INTERNAL: ${{ github.event_name != 'schedule' && github.event_name != 'release'}}
RELEASE: ${{ github.event_name =='release' || (startsWith(github.base_ref, 'release/') && github.event_name == 'pull_request')}}
DEV: ${{ github.base_ref == 'main' && github.event_name == 'pull_request'}}
NAME: ${{ github.event.number }}

jobs:
build_and_push:
runs-on: ubuntu-latest
outputs:
wheel: ${{ steps.push-wheel.outputs.wheel }}
steps:
- name: Checkout code
uses: actions/checkout@v3
- name: Login to s3
uses: aws-actions/configure-aws-credentials@v2
with:
role-to-assume: ${{ secrets.AWS_WEBIDENTITY_FOR_GITHUB_ACTIONS }}
aws-region: us-east-1
- name: Build PyPi Wheel
id: build-wheel
uses: neuralmagic/nm-actions/actions/pypi_build@main
with:
dev: $DEV
release: $RELEASE
name: $NAME
- name: Push to s3 bucket
id: push-wheel
uses: neuralmagic/nm-actions/actions/s3_push@main
with:
filename: dist/*.whl
internal: $INTERNAL
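For reference, the INTERNAL/RELEASE/DEV expressions above reduce to the following Python sketch (illustrative only, not part of the workflow; the function name classify_build is hypothetical):

# Illustrative sketch: how the workflow's env flags resolve per triggering event.
def classify_build(event_name: str, base_ref: str = "") -> dict:
    internal = event_name not in ("schedule", "release")
    release = event_name == "release" or (
        base_ref.startswith("release/") and event_name == "pull_request"
    )
    dev = base_ref == "main" and event_name == "pull_request"
    return {"INTERNAL": internal, "RELEASE": release, "DEV": dev}

# PRs into main produce dev builds pushed internally; the nightly cron job and
# GitHub releases are the only builds that are not marked internal.
assert classify_build("pull_request", "main") == {"INTERNAL": True, "RELEASE": False, "DEV": True}
assert classify_build("schedule") == {"INTERNAL": False, "RELEASE": False, "DEV": False}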
8 changes: 7 additions & 1 deletion setup.py
@@ -20,6 +20,7 @@

# default variables to be overwritten by the version.py file
is_release = None
is_dev = None
version = "unknown"
version_major_minor = version

@@ -28,7 +29,12 @@
print(f"loaded version {version} from src/sparseml/version.py")
version_nm_deps = f"{version_major_minor}.0"

_PACKAGE_NAME = "sparseml" if is_release else "sparseml-nightly"
if is_release:
_PACKAGE_NAME = "sparseml"
elif is_dev:
_PACKAGE_NAME = "sparseml-dev"
else:
_PACKAGE_NAME = "sparseml-nightly"

_deps = [
"setuptools<=59.5.0",
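A minimal sketch (not part of the commit) of what the new selection logic in setup.py amounts to, with the flags loaded from src/sparseml/version.py:

def package_name(is_release: bool, is_dev: bool) -> str:
    # release builds keep the canonical name; dev and nightly builds get suffixes
    if is_release:
        return "sparseml"
    elif is_dev:
        return "sparseml-dev"
    return "sparseml-nightly"

assert package_name(True, False) == "sparseml"
assert package_name(False, True) == "sparseml-dev"
assert package_name(False, False) == "sparseml-nightly"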
@@ -68,7 +68,6 @@ def lm_eval_harness(
:param kwargs: additional keyword arguments to pass to the
lm-evaluation-harness. For example, `limit`
"""

kwargs["limit"] = int(limit) if (limit := kwargs.get("limit")) else None

tokenizer = SparseAutoTokenizer.from_pretrained(model_path)
3 changes: 3 additions & 0 deletions src/sparseml/modifiers/pruning/constant/pytorch.py
@@ -71,6 +71,9 @@ def on_update(self, state: State, event: Event, **kwargs):
def apply_masks(module):
mask_name = param_mask_name()
if hasattr(module, mask_name):
mask = getattr(module, mask_name)
if mask.device != module.weight.device:
setattr(module, mask_name, mask.to(module.weight.device))
module.weight *= getattr(module, mask_name)

state.model.model.apply(apply_masks)
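The added lines guard against a pruning mask that ends up on a different device than the module weight (for example, a mask restored on CPU while the weight lives on GPU). A minimal sketch of the failure mode and the fix, assuming a CUDA device is available:

import torch

weight = torch.ones(4, device="cuda") if torch.cuda.is_available() else torch.ones(4)
mask = torch.tensor([1.0, 0.0, 1.0, 0.0])  # e.g. left on CPU after loading

# without this move, `weight *= mask` raises a device-mismatch error on GPU weights
if mask.device != weight.device:
    mask = mask.to(weight.device)
weight *= mask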
19 changes: 18 additions & 1 deletion src/sparseml/transformers/finetune/data/custom.py
@@ -17,6 +17,10 @@
from datasets.dataset_dict import Dataset, DatasetDict

from sparseml.transformers.finetune.data import TextGenerationDataset
from sparseml.transformers.utils.preprocessing_functions import (
PreprocessingFunctionRegistry,
)
from sparsezoo.utils.helpers import import_from_path


@TextGenerationDataset.register(name="custom", alias=["json", "csv"])
@@ -55,9 +59,21 @@ def get_raw_dataset(self, *_ignore, **__ignore) -> Union[DatasetDict, Dataset]:
raw_dataset = super().get_raw_dataset()

if self.preprocessing_func is not None:

if callable(self.preprocessing_func):
func = self.preprocessing_func
elif ":" in self.preprocessing_func:
# load func_name from "/path/to/file.py:func_name"
func = import_from_path(self.preprocessing_func)
else:
# load from the registry
func = PreprocessingFunctionRegistry.get_value_from_registry(
name=self.preprocessing_func
)

raw_dataset = self.map(
raw_dataset,
function=self.preprocessing_func,
function=func,
batched=False,
num_proc=self.data_args.preprocessing_num_workers,
desc="Applying custom func to the custom dataset",
@@ -82,6 +98,7 @@ def get_remove_columns_from_dataset(
self, raw_dataset: Union[DatasetDict, Dataset]
) -> List[str]:
"""Remove redandant columns from the dataset for processing"""

remove_columns = raw_dataset.column_names
if isinstance(remove_columns, Dict):
remove_columns = raw_dataset[list(raw_dataset.keys())[0]].column_names
11 changes: 9 additions & 2 deletions src/sparseml/transformers/finetune/data/data_args.py
@@ -55,8 +55,15 @@ class CustomDataTrainingArguments(DVCDatasetTrainingArguments):
metadata={"help": "Column names to remove after preprocessing custom datasets"},
)

preprocessing_func: Optional[Callable] = field(
default=None, metadata={"help": "The preprocessing function to apply"}
preprocessing_func: Union[None, str, Callable] = field(
default=None,
metadata={
"help": (
"The preprocessing function to apply ",
"or the preprocessing func name in "
"src/sparseml/transformers/utils/preprocessing_functions.py",
)
},
)


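Together with the custom.py change above, preprocessing_func now accepts three forms. A sketch (not part of the commit) of how CustomDataset resolves them; the helper name resolve_preprocessing_func is illustrative, but the calls mirror the diffed code:

def resolve_preprocessing_func(preprocessing_func):
    if callable(preprocessing_func):
        # 1) a plain callable passed directly on the data args
        return preprocessing_func
    if ":" in preprocessing_func:
        # 2) "/path/to/file.py:func_name", loaded via sparsezoo's import_from_path
        from sparsezoo.utils.helpers import import_from_path
        return import_from_path(preprocessing_func)
    # 3) the name of a function registered in PreprocessingFunctionRegistry,
    #    e.g. "custom_evolved_codealpaca_dataset"
    from sparseml.transformers.utils.preprocessing_functions import (
        PreprocessingFunctionRegistry,
    )
    return PreprocessingFunctionRegistry.get_value_from_registry(name=preprocessing_func)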
33 changes: 33 additions & 0 deletions src/sparseml/transformers/finetune/data/data_helpers.py
@@ -89,6 +89,7 @@ def get_raw_dataset(
:return: the requested dataset
"""

raw_datasets = load_dataset(
data_args.dataset,
data_args.dataset_config_name,
@@ -125,6 +126,7 @@ def make_dataset_splits(
tokenized_datasets = {"train": tokenized_datasets}

train_split = eval_split = predict_split = calib_split = None

if do_train:
if "train" not in tokenized_datasets:
raise ValueError("--do_train requires a train dataset")
@@ -218,4 +220,35 @@ def get_custom_datasets_from_path(path: str, ext: str = "json") -> Dict[str, str
if dir_dataset:
data_files[dir_name] = dir_dataset

return transform_dataset_keys(data_files)


def transform_dataset_keys(data_files: Dict[str, Any]):
"""
Transform dict keys to `train`, `val` or `test` for the given input dict
if matches exist with the existing keys. Note that there can only be one
matching file name.
Ex. Folder(train_eval.json) -> Folder(train.json)
Folder(train1.json, train2.json) -> Same
:param data_files: The dict where keys will be transformed
"""
keys = set(data_files.keys())

def transform_dataset_key(candidate: str) -> None:
for key in keys:
if candidate in key:
if key == candidate:
return
val = data_files.pop(key)
data_files[candidate] = val

def do_transform(candidate: str) -> bool:
return sum(candidate in key for key in keys) == 1

dataset_keys = ("train", "val", "test")
for dataset_key in dataset_keys:
if do_transform(dataset_key):
transform_dataset_key(dataset_key)

return data_files
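A small usage sketch (not part of the commit) of the new helper; the input dicts below are hypothetical:

from sparseml.transformers.finetune.data.data_helpers import transform_dataset_keys

print(transform_dataset_keys({"training_data": "training_data.json"}))
# {'train': 'training_data.json'} -- a single matching key is renamed to the canonical split name

print(transform_dataset_keys({"train1": "train1.json", "train2": "train2.json"}))
# left unchanged -- "train" matches more than one key, so nothing is renamed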
8 changes: 0 additions & 8 deletions src/sparseml/transformers/finetune/runner.py
@@ -40,9 +40,7 @@
)
from sparseml.transformers.finetune.model_args import ModelArguments
from sparseml.transformers.finetune.training_args import TrainingArguments
from sparseml.utils.fsdp.context import summon_full_params_context
from sparseml.utils.fsdp.helpers import is_fsdp_model, unwrap_and_export_model
from sparseml.utils.pytorch import qat_active


_LOGGER: logging.Logger = logging.getLogger(__name__)
@@ -288,12 +286,6 @@ def run_sequential_stages(self, checkpoint: Optional[str] = None):
session = session_manager.active_session()
session.reset_stage()

# log model sparsity
with summon_full_params_context(self.trainer.model):
if self.trainer.accelerator.is_main_process:
if not qat_active(self.trainer.model):
self.trainer.log_model_sparsification()

# synchronize and clean up memory
self.trainer.accelerator.wait_for_everyone()
self.trainer.model = get_session_model()
19 changes: 15 additions & 4 deletions src/sparseml/transformers/finetune/session_mixin.py
@@ -40,6 +40,7 @@
)
from sparseml.utils.fsdp.context import summon_full_params_context
from sparseml.utils.fsdp.helpers import is_fsdp_model, save_pretrained_fsdp
from sparseml.utils.pytorch import qat_active


__all__ = [
@@ -137,7 +138,7 @@ def initialize_session(
train_data = self.get_train_dataloader()

self.accelerator.wait_for_everyone()
with summon_full_params_context(self.model):
with summon_full_params_context(self.model, offload_to_cpu=True):
session_manager.initialize(
model=self.model,
teacher_model=self.teacher, # TODO: what about for self/disable?
@@ -370,9 +371,13 @@ def train(self, *args, stage: Optional[str] = None, **kwargs):

self.accelerator.wait_for_everyone()

# Need to gather parameters across the GPUs before accessing layer weights
with summon_full_params_context(self.model):
self.log_model_sparsification()
# log model sparsity
with summon_full_params_context(self.model, offload_to_cpu=True):
if self.accelerator.is_main_process:
if not qat_active(self.model):
self.log_model_sparsification()

self.accelerator.wait_for_everyone()

return output

@@ -434,6 +439,12 @@ def one_shot(self, calib_data: DataLoader, stage: Optional[str] = None):
accelerator=self.accelerator,
)

# log model sparsity
with summon_full_params_context(self.model, offload_to_cpu=True):
if self.accelerator.is_main_process:
if not qat_active(self.model):
self.log_model_sparsification()

self.accelerator.wait_for_everyone()

def save_model(
1 change: 1 addition & 0 deletions src/sparseml/transformers/utils/__init__.py
@@ -20,6 +20,7 @@
from .helpers import *
from .load_task_dataset import *
from .metrics import *
from .preprocessing_functions import *
from .sparse_config import *
from .sparse_model import *
from .sparse_tokenizer import *
29 changes: 29 additions & 0 deletions src/sparseml/transformers/utils/preprocessing_functions.py
@@ -0,0 +1,29 @@
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict

from sparsezoo.utils.registry import RegistryMixin


class PreprocessingFunctionRegistry(RegistryMixin):
...


@PreprocessingFunctionRegistry.register()
def custom_evolved_codealpaca_dataset(data: Dict):
PROMPT_DICT = """[Instruction]:\n{instruction}\n\n[Response]:"""
data["prompt"] = PROMPT_DICT.format_map(data)
data["text"] = data["prompt"] + data["output"]
return data
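A hedged sketch (not part of the commit) of how an additional preprocessing function could be registered and then referenced by name from the data args; the function name and prompt format below are hypothetical, and this only works if the defining module is imported before the dataset is built:

from typing import Dict

from sparseml.transformers.utils.preprocessing_functions import (
    PreprocessingFunctionRegistry,
)


@PreprocessingFunctionRegistry.register()
def my_question_answer_dataset(data: Dict):
    # hypothetical prompt format; adjust to the dataset's actual columns
    data["text"] = f"[Question]:\n{data['question']}\n\n[Answer]:\n{data['answer']}"
    return data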
13 changes: 8 additions & 5 deletions src/sparseml/version.py
@@ -21,14 +21,17 @@

version_base = "1.7.0"
is_release = False # change to True to set the generated version as a release version
is_dev = False
dev_number = None


def _generate_version():
return (
version_base
if is_release
else f"{version_base}.{date.today().strftime('%Y%m%d')}"
)
if is_release:
return version_base
elif is_dev:
return f"{version_base}.dev{dev_number}"
else:
return f"{version_base}.{date.today().strftime('%Y%m%d')}"


__all__ = [
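For reference, a sketch (not part of the commit) of the version string each flag combination produces, assuming version_base = "1.7.0"; the dev number shown is hypothetical:

from datetime import date


def generate_version(is_release: bool, is_dev: bool, dev_number=None) -> str:
    version_base = "1.7.0"
    if is_release:
        return version_base                           # "1.7.0"
    elif is_dev:
        return f"{version_base}.dev{dev_number}"      # e.g. "1.7.0.dev42"
    return f"{version_base}.{date.today().strftime('%Y%m%d')}"  # e.g. "1.7.0.20240315"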
