Skip to content

Commit

Permalink
Update train.py
Browse files · Browse the repository at this point in the history
  • Loading branch information
UmerHA committed Apr 27, 2024
1 parent 467933f commit 078dd4f
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion train.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,7 +577,7 @@ def fsdp_main(local_rank:int, world_size:int, args:Dict):
model = AutoModelForCausalLM.from_config(cfg)
if args["train_type"] in ["hqq_lora", "hqq_dora", "hqq_llama_pro"]:
# TODO: Tune BaseQuantizeConfig.
quant_config = BaseQuantizeConfig(nbits=4, group_size=64, quant_zero=True,
quant_config = BaseQuantizeConfig(nbits=int(args["n_bits"]), group_size=64, quant_zero=True,
quant_scale=True, offload_meta=True, view_as_float=True)
model.model = replace_linear(model.model, HQQLinear, quant_config, device=rank,
compute_dtype=compute_dtype, del_orig=True, initialize=False, skip_modules=skip_modules)
Expand Down Expand Up @@ -1032,6 +1032,8 @@ def main(
name: str = None, # For wandb logging
group: str = None, # For wandb logging
entity: str = None, # For wandb logging
# ---- added by Umer
n_bits: int = 4, # passed to hqq
):

# Set world size
Expand Down

0 comments on commit 078dd4f

Please sign in to comment.