
Commit

fmt and lint
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Sep 28, 2024
1 parent e87f351 commit 177d6d7
Showing 3 changed files with 18 additions and 10 deletions.
@@ -124,37 +124,44 @@ def model_loader(self, model_name: str, **kwargs):
         config_kwargs["bnb_4bit_quant_storage"] = torch_dtype

         # - of course assume that this package must exist, simply need the version
-        _, _transformers_version = _is_package_available("transformers", return_version=True)
+        _, _transformers_version = _is_package_available(
+            "transformers", return_version=True
+        )

         # this is a workaround that disables low_cpu_mem_mode for quant QLORA
         # - this issue was introduced in https://github.com/huggingface/transformers/pull/33154
-        #   whereby the low_cpu_mem_mode was actually fixed.
+        #   whereby the low_cpu_mem_mode was actually fixed.
         # - However fixing it causes some problems with the current impl.
-        #   1. For lora fused ops, the adapters cannot be managed by FSDP, as
-        #      forwards are not called. This causes issue 2) in
+        #   1. For lora fused ops, the adapters cannot be managed by FSDP, as
+        #      forwards are not called. This causes issue 2) in
         #      https://github.com/foundation-model-stack/fms-acceleration/issues/83
         #      where the adapters are still sharded when passed in the fused-ops.
         #      However, if low_cpu_mem_mode=True, then we NEED FSDP to initialize
         #      their state, which contradicts the above point.
-        #
-        #   2. We have observed,
+        #
+        #   2. We have observed,
         #      see https://github.com/foundation-model-stack/fms-acceleration/pull/86
         #      that low_cpu_mem_mode=True can cause torch distributed primitives
         #      to hang.

         if _transformers_version >= "4.45":

             # pylint: disable=import-outside-toplevel
             # Third Party
             from fms_acceleration.model_patcher import patch_target_module
             import transformers.modeling_utils

             def _truthy():
-                return True  # use this to always return True to is_local_dist_rank_0
+                return (
+                    True  # use this to always return True to is_local_dist_rank_0
+                )

             # - we cannot use the model patcher and this needs to be called immediately below
             #   at the model_loader
             # - but we immediately revert the patch after loading
-            patched_is_local_dist_rank_0 = transformers.modeling_utils.is_local_dist_rank_0
+            patched_is_local_dist_rank_0 = (
+                transformers.modeling_utils.is_local_dist_rank_0
+            )
             patch_target_module(
                 "transformers.modeling_utils.is_local_dist_rank_0",
                 _truthy,
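To make the workaround above easier to follow: the hunk saves the original transformers.modeling_utils.is_local_dist_rank_0, patches it with a function that always returns True while the quantized model is loaded, and reverts the patch immediately afterwards. Below is a minimal sketch of that flow, not the plugin's actual code: the loader function, the AutoModelForCausalLM.from_pretrained call, and the restore step are assumptions, since the hunk is truncated before the revert; patch_target_module itself does come from fms_acceleration.model_patcher as shown above.

# Sketch only: reconstructs the patch-and-revert pattern from the hunk above.
import transformers.modeling_utils
from transformers import AutoModelForCausalLM
from fms_acceleration.model_patcher import patch_target_module


def _load_without_low_cpu_mem(model_name: str, **kwargs):
    """Illustrative loader, not the plugin's actual model_loader."""

    def _truthy():
        # every rank claims to be local rank 0, which switches off the
        # low_cpu_mem loading path that conflicts with the fused-ops adapters
        return True

    # keep a handle on the original function so it can be restored
    original_fn = transformers.modeling_utils.is_local_dist_rank_0
    patch_target_module(
        "transformers.modeling_utils.is_local_dist_rank_0", _truthy
    )
    try:
        return AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
    finally:
        # revert immediately after loading (restore step assumed; the diff is
        # truncated before the actual revert)
        transformers.modeling_utils.is_local_dist_rank_0 = original_fn

One side note on the guard around this block: _transformers_version >= "4.45" is a plain string comparison, which orders versions lexicographically (a hypothetical "4.100.0" would compare as smaller than "4.45"); packaging.version.parse is the usual way to compare release versions robustly.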
plugins/fused-ops-and-kernels/tests/test_foak_plugins.py (1 change: 1 addition & 0 deletions)
@@ -35,6 +35,7 @@
     DIRNAME, "../configs/fast_quantized_peft.yaml"
 )

+
 @pytest.mark.skip(reason="Installation logic has changed - test to be fixed in future.")
 def test_configure_gptq_foak_plugin():
     "test foak plugin loads correctly"
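The only change in this file is the second blank line before the decorated test, which is the usual two-blank-lines-before-a-top-level-definition rule (pycodestyle E302, also enforced by black); presumably that is what the lint pass flagged. A minimal illustration with placeholder names:

import pytest

CONFIG_PATH = "../configs/fast_quantized_peft.yaml"  # placeholder constant


@pytest.mark.skip(reason="placeholder reason")  # two blank lines above satisfy E302
def test_placeholder():
    "placeholder test"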
scripts/benchmarks/benchmark.py (2 changes: 1 addition & 1 deletion)
@@ -1048,7 +1048,7 @@ def main(args):
         help="accelerate config file path",
     )
     parser.add_argument(
-        "--process_port", type=int, default=29500, help="accelerate process port"
+        "--process_port", type=int, default=29511, help="accelerate process port"
     )
     parser.add_argument(
         "--no_data_processing",
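For context on the new default: 29500 is also the default rendezvous port for torch.distributed (and therefore accelerate), so moving the benchmark off it presumably avoids collisions with other distributed jobs on the same host. A small self-contained sketch of the overridable-port pattern, reusing the add_argument call from the diff; the override value 29600 is just an example:

import argparse

# the benchmark exposes the accelerate process port as a CLI flag so it can be
# moved away from ports already in use
parser = argparse.ArgumentParser(description="process-port example")
parser.add_argument(
    "--process_port", type=int, default=29511, help="accelerate process port"
)

args = parser.parse_args(["--process_port", "29600"])  # override on the command line
assert args.process_port == 29600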
