improvements to is_local_dist_rank patching
Signed-off-by: Yu Chin Fabian Lim <flim@sg.ibm.com>
fabianlim committed Sep 27, 2024
1 parent 0258544 commit 3993b8c
Showing 1 changed file with 38 additions and 3 deletions.
@@ -116,27 +116,55 @@ def model_loader(self, model_name: str, **kwargs):
        except ValueError:
            world_size = 1  # pg not init

        patched_is_local_dist_rank_0 = None
        if (
            world_size > 1
            and os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
        ):
            config_kwargs["bnb_4bit_quant_storage"] = torch_dtype

            # - of course assume that this package must exist, simply need the version
            _, _transformers_version = _is_package_available("transformers", return_version=True)

            # this is a workaround that disables low_cpu_mem_mode for quant QLORA
            # - this issue was introduced in https://github.com/huggingface/transformers/pull/33154
            #   whereby low_cpu_mem_mode was actually fixed.
            # - However, fixing it causes some problems with the current impl:
            #   1. For lora fused ops, the adapters cannot be managed by FSDP, as
            #      forwards are not called. This causes issue 2) in
            #      https://github.com/foundation-model-stack/fms-acceleration/issues/83
            #      where the adapters are still sharded when passed into the fused-ops.
            #      However, if low_cpu_mem_mode=True, then we NEED FSDP to initialize
            #      their state, which contradicts the above point.
            #
            #   2. We have observed,
            #      see https://github.com/foundation-model-stack/fms-acceleration/pull/86
            #      that low_cpu_mem_mode=True can cause torch distributed primitives
            #      to hang.

            if _transformers_version >= "4.45":

                # pylint: disable=import-outside-toplevel
                from fms_acceleration.model_patcher import patch_target_module
                import transformers.modeling_utils

                def _truthy():
                    return True  # use this to always return True for is_local_dist_rank_0

                # - we cannot use the model patcher, as this needs to be called immediately
                #   below at the model_loader
                # - but we immediately revert the patch after loading
                patched_is_local_dist_rank_0 = transformers.modeling_utils.is_local_dist_rank_0
                patch_target_module(
                    "transformers.modeling_utils.is_local_dist_rank_0",
                    _truthy,
                )

                warnings.warn(
                    "Disabling low_cpu_mem_mode in the BNBAccelerationPlugin as this may "
                    "potentially cause problems with: "
                    "1. the fused-ops-and-kernels package, and "
                    "2. the syncing of FSDP modules across devices."
                )

        elif world_size > 1:
@@ -171,6 +199,13 @@ def _truthy():
            attn_implementation=attn_implementation,
        )

        if patched_is_local_dist_rank_0 is not None:
            # restore the original is_local_dist_rank_0
            patch_target_module(
                "transformers.modeling_utils.is_local_dist_rank_0",
                patched_is_local_dist_rank_0,
            )

        return model

    @property
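The diff above implements a save/patch/restore pattern: keep a handle to the original transformers.modeling_utils.is_local_dist_rank_0, patch it to always return True for the duration of the load (which the plugin relies on to disable low_cpu_mem_mode), then put the original back once the model is loaded. Below is a minimal standalone sketch of that pattern, not the plugin's code; the helper name and the use of AutoModelForCausalLM are illustrative assumptions, while patch_target_module and is_local_dist_rank_0 are the calls shown in the diff.

    import transformers.modeling_utils
    from transformers import AutoModelForCausalLM

    from fms_acceleration.model_patcher import patch_target_module


    def load_with_local_rank_0_forced(model_name: str, **kwargs):
        # keep a handle to the original function so it can be restored afterwards
        original_is_local_dist_rank_0 = transformers.modeling_utils.is_local_dist_rank_0

        # patch it to always report local rank 0; as in the plugin, this disables
        # the low_cpu_mem_mode path while the checkpoint is loaded
        patch_target_module(
            "transformers.modeling_utils.is_local_dist_rank_0",
            lambda: True,
        )
        try:
            model = AutoModelForCausalLM.from_pretrained(model_name, **kwargs)
        finally:
            # immediately revert the patch so later loads see the original behaviour
            patch_target_module(
                "transformers.modeling_utils.is_local_dist_rank_0",
                original_is_local_dist_rank_0,
            )
        return model

The try/finally is an editorial addition over the commit, which reverts the patch right after from_pretrained returns; wrapping the restore in finally simply guarantees it also runs if loading raises.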
