Skip to content

Commit

Permalink
pass module to Params4bit.from_prequantized to ensure quant_state (huggingface#32524)
Browse files Browse the repository at this point in the history

* pass module to Params4bit.from_prequantized to ensure quant_state

* make sure to check bnb version

* revert min bnb version and use inspect on method instead

* use version instead of inspect to prevent performance hit

* make the property name readable
  • Loading branch information
winglian authored and zucchini-nlp committed Aug 30, 2024
1 parent 87bfa1d commit 32b79fb
Showing 1 changed file with 15 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/transformers/quantizers/quantizer_bnb_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
from functools import cached_property
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

from packaging import version
Expand Down Expand Up @@ -207,11 +208,16 @@ def create_quantized_param(
if unexpected_keys is not None and k in unexpected_keys:
unexpected_keys.remove(k)

param_kwargs = {}
if self.is_bnb_supports_quant_storage_module:
param_kwargs["module"] = module

new_value = bnb.nn.Params4bit.from_prequantized(
data=param_value,
quantized_stats=quantized_stats,
requires_grad=False,
device=target_device,
**param_kwargs,
)
else:
new_value = param_value.to("cpu")
Expand Down Expand Up @@ -318,6 +324,15 @@ def is_serializable(self):

return True

@cached_property
def is_bnb_supports_quant_storage_module(self) -> bool:
    """
    Whether the installed bitsandbytes supports the `module` parameter
    in `Params4bit.from_prequantized`.

    The decision is based on the installed distribution version (at least
    0.43.3). The result is memoized by `cached_property`, so the metadata
    lookup runs at most once per instance.

    Returns:
        `bool`: `True` if the installed bitsandbytes version is >= 0.43.3.

    Raises:
        importlib.metadata.PackageNotFoundError: if no bitsandbytes
            distribution is installed.
    """
    # A bare `import importlib` does not guarantee that the `metadata`
    # submodule has been loaded; import it explicitly rather than relying on
    # some other module having imported it first.
    from importlib import metadata as importlib_metadata

    installed_version = version.parse(importlib_metadata.version("bitsandbytes"))
    return installed_version >= version.parse("0.43.3")

@property
def is_trainable(self) -> bool:
    """Report whether this quantizer is trainable; unconditionally `True`."""
    trainable = True
    return trainable
Expand Down

0 comments on commit 32b79fb

Please sign in to comment.